
Conversation

@lihaoyang-amd
Contributor

@lihaoyang-amd lihaoyang-amd commented Jun 17, 2025

For ROCm only.
1. Add quickreduce as an alternative to custom allreduce and RCCL. (For large message sizes, custom quick reduce is used instead of custom allreduce and RCCL; see the kernel test results below.)

2. The collective is only enabled on AMD MI300, for fp16/bf16 inputs, and only when custom allreduce is enabled. The kernels support full precision as well as quantized int8, int6, and int4 (symmetric quantization with group size 32) all-reduce.

3. Quickreduce can be enabled by setting the environment variable VLLM_ROCM_QUICK_REDUCE_QUANTIZATION=[NONE|FP|INT8|INT6|INT4]. Quickreduce supports int8, int6, and int4 quantization; NONE turns quick allreduce off (see the launch example after this list).

4. The PR provides both fp16 and bf16 kernels, but given the lack of intrinsics for bf16 math operations, the bf16 kernels perform worse (see the kernel benchmark results below), so by default bf16 all-reduce inputs are cast to fp16. To disable this behavior, set the environment variable VLLM_ROCM_QR_CAST_BF16_TO_FP16=0.

5. Since quickreduce only provides performance benefits at medium and large input sizes (see the kernel benchmarks), vLLM keeps using custom allreduce for small inputs. The lower size limit for enabling quickreduce was chosen based on experimental results.

6. The default maximum input size for quickreduce is 2GB. For users with limited GPU memory the preset buffer may be too large; the value (in MB) can be adjusted via VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB.
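
Putting the knobs above together, here is a minimal launch sketch (the quantization level, buffer size, model, and TP degree below are illustrative choices, not values prescribed by this PR):

```bash
# Illustrative values only: enable quick allreduce with int4 quantization,
# keep the default bf16 -> fp16 cast, and cap the preset buffer at 1 GB.
export VLLM_ROCM_QUICK_REDUCE_QUANTIZATION=INT4
export VLLM_ROCM_QR_CAST_BF16_TO_FP16=1
export VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB=1024

# Serve with tensor parallelism so the all-reduce path is actually exercised.
vllm serve meta-llama/Llama-3.1-70B-Instruct -tp 8 --dtype float16
```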

Kernel benchmarks

Baseline is custom allreduce when the data size is less than 16MB and RCCL when the data size is greater than 16MB.

TP=2

| msg size | baseline | QR FP | QR int8 | QR int6 | QR int4 | QR FP bf16 | QR int8 bf16 | QR int6 bf16 | QR int4 bf16 |
|---|---|---|---|---|---|---|---|---|---|
| 2.0KB | 7.26 | 10.75 | 22.05 | 23.79 | 19.04 | 13.11 | 92.69 | 94.29 | 92.05 |
| 32.0KB | 7.40 | 10.88 | 22.09 | 24.18 | 19.16 | 13.23 | 93.39 | 95.16 | 92.63 |
| 256.0KB | 11.87 | 14.91 | 23.89 | 25.54 | 20.35 | 17.56 | 95.58 | 96.54 | 94.54 |
| 512.0KB | 18.08 | 19.93 | 25.62 | 25.88 | 21.17 | 22.09 | 95.30 | 96.18 | 94.66 |
| 1.0MB | 30.07 | 30.57 | 31.19 | 29.81 | 23.57 | 32.55 | 96.10 | 97.14 | 95.30 |
| 2.0MB | 53.76 | 51.96 | 43.50 | 39.25 | 30.03 | 54.08 | 99.33 | 99.15 | 96.80 |
| 4.0MB | 102.19 | 97.90 | 66.63 | 57.63 | 42.74 | 98.62 | 125.23 | 116.04 | 105.32 |
| 8.0MB | 199.01 | 190.19 | 115.36 | 95.66 | 68.28 | 191.92 | 178.99 | 158.96 | 138.31 |
| 16.0MB | 391.43 | 378.94 | 219.60 | 174.70 | 125.08 | 378.34 | 287.39 | 244.03 | 201.06 |
| 32.0MB | 892.99 | 739.28 | 425.61 | 339.69 | 243.53 | 741.36 | 517.28 | 437.13 | 351.26 |
| 64.0MB | 1465.11 | 1466.43 | 828.58 | 650.82 | 464.56 | 1470.50 | 953.49 | 801.14 | 635.16 |
| 128.0MB | 2912.45 | 2917.14 | 1634.59 | 1277.78 | 898.42 | 2935.37 | 1777.66 | 1450.88 | 1153.58 |
| 256.0MB | 5927.01 | 5822.34 | 3252.23 | 2534.88 | 1772.40 | 5866.49 | 3433.40 | 2804.97 | 2192.04 |
| 512.0MB | 11575.91 | 11639.26 | 6491.12 | 5058.28 | 3522.78 | 11727.95 | 6809.77 | 5505.87 | 4262.04 |
| 1GB | 23223.61 | 23255.00 | 12971.35 | 10106.85 | 7023.81 | 23435.05 | 13586.55 | 10945.95 | 8392.96 |
| 2GB | 45968.99 | 46101.59 | 26021.00 | 20228.00 | 14164.00 | 47084.30 | 27227.85 | 21836.00 | 16624.25 |

TP=4

| msg size | baseline | QR FP | QR int8 | QR int6 | QR int4 | QR FP bf16 | QR int8 bf16 | QR int6 bf16 | QR int4 bf16 |
|---|---|---|---|---|---|---|---|---|---|
| 2.0KB | 7.18 | 12.67 | 23.89 | 25.62 | 21.33 | 16.14 | 79.68 | 93.95 | 84.15 |
| 16.0KB | 7.26 | 12.79 | 23.69 | 25.84 | 21.29 | 16.12 | 80.08 | 93.89 | 84.31 |
| 32.0KB | 7.58 | 12.87 | 23.81 | 25.94 | 21.27 | 16.20 | 80.72 | 94.33 | 84.87 |
| 256.0KB | 14.39 | 15.88 | 26.78 | 27.48 | 22.85 | 19.18 | 82.15 | 94.74 | 86.22 |
| 512.0KB | 14.65 | 17.76 | 27.12 | 27.98 | 23.65 | 21.01 | 81.73 | 94.75 | 86.30 |
| 1.0MB | 22.91 | 22.49 | 29.43 | 29.85 | 24.72 | 26.34 | 82.55 | 95.92 | 81.59 |
| 2.0MB | 36.64 | 36.08 | 40.97 | 39.27 | 28.49 | 38.45 | 86.48 | 97.98 | 82.35 |
| 4.0MB | 63.86 | 63.28 | 66.95 | 54.81 | 37.49 | 72.12 | 109.41 | 114.94 | 90.81 |
| 8.0MB | 118.31 | 126.69 | 126.43 | 99.41 | 64.23 | 137.45 | 168.64 | 157.50 | 116.16 |
| 16.0MB | 230.48 | 237.08 | 204.46 | 167.08 | 109.77 | 237.42 | 290.08 | 256.74 | 183.08 |
| 32.0MB | 389.03 | 439.12 | 390.55 | 307.40 | 217.65 | 441.30 | 470.49 | 440.30 | 304.83 |
| 64.0MB | 1017.56 | 825.53 | 654.79 | 509.82 | 364.77 | 837.36 | 803.16 | 731.90 | 522.51 |
| 128.0MB | 1910.37 | 1587.00 | 1090.06 | 848.89 | 596.27 | 1606.67 | 1307.73 | 1220.87 | 886.31 |
| 256.0MB | 3542.03 | 3082.80 | 1970.84 | 1535.23 | 1078.49 | 3135.44 | 2281.91 | 2180.19 | 1613.20 |
| 512.0MB | 6560.81 | 6098.23 | 3735.02 | 2892.65 | 2015.72 | 6185.95 | 4282.83 | 4096.10 | 3154.83 |
| 1GB | 12582.56 | 12105.15 | 7275.68 | 5618.14 | 3895.45 | 12288.60 | 8317.19 | 7991.48 | 6231.14 |
| 2GB | 24453.95 | 24570.59 | 14636.20 | 11087.40 | 7685.00 | 24529.95 | 16488.65 | 15956.70 | 12265.90 |

TP=8

| size | baseline | QR FP | QR Q8 | QR Q6 | QR Q4 | QR FP bf | QR Q8 bf | QR Q6 bf | QR Q4 bf |
|---|---|---|---|---|---|---|---|---|---|
| 4k | 14.07 | 14.47 | 28.42 | 30.65 | 28.31 | 19.44 | 79.46 | 86.40 | 78.88 |
| 8k | 12.99 | 15.52 | 30.99 | 31.88 | 27.69 | 19.18 | 80.49 | 87.08 | 79.59 |
| 16k | 12.38 | 15.13 | 29.89 | 31.78 | 27.46 | 18.99 | 81.10 | 87.71 | 80.10 |
| 32k | 12.29 | 15.38 | 29.79 | 31.79 | 28.73 | 19.21 | 81.89 | 83.33 | 81.52 |
| 64k | 12.39 | 22.50 | 31.53 | 31.81 | 28.92 | 19.87 | 82.59 | 84.53 | 81.33 |
| 128k | 13.75 | 20.61 | 31.21 | 33.35 | 40.00 | 22.22 | 83.23 | 87.84 | 82.24 |
| 256k | 17.13 | 21.05 | 37.70 | 47.96 | 45.16 | 25.07 | 82.45 | 87.31 | 84.08 |
| 512k | 17.51 | 22.04 | 32.41 | 34.73 | 37.60 | 25.06 | 83.44 | 86.34 | 85.00 |
| 1M | 24.92 | 23.44 | 32.76 | 34.57 | 31.06 | 27.88 | 85.35 | 87.96 | 83.70 |
| 2M | 36.06 | 30.27 | 42.37 | 41.19 | 30.98 | 34.09 | 90.69 | 90.69 | 88.53 |
| 4M | 50.92 | 46.09 | 47.26 | 43.68 | 34.15 | 49.57 | 113.57 | 93.56 | 92.43 |
| 8M | 83.69 | 84.15 | 81.73 | 66.15 | 47.80 | 85.35 | 125.36 | 122.75 | 115.00 |
| 16M | 149.02 | 148.81 | 188.60 | 137.84 | 83.37 | 144.57 | 246.64 | 180.91 | 152.14 |
| 32M | 239.78 | 259.39 | 345.67 | 255.08 | 178.61 | 253.05 | 402.91 | 336.31 | 265.64 |
| 64M | 425.57 | 463.76 | 520.00 | 399.89 | 281.86 | 466.72 | 690.57 | 569.52 | 463.40 |
| 128M | 793.60 | 861.38 | 787.08 | 618.50 | 444.63 | 863.46 | 1050.74 | 899.01 | 770.36 |
| 256M | 1537.40 | 1652.66 | 1307.59 | 1101.47 | 862.49 | 1652.09 | 1788.94 | 1661.50 | 1516.37 |
| 512M | 3028.06 | 3230.72 | 2333.97 | 2036.01 | 1678.64 | 3215.32 | 3480.68 | 3286.91 | 3064.64 |
| 1GB | 5993.59 | 6392.41 | 4396.49 | 3932.57 | 3325.08 | 6366.13 | 6726.18 | 6470.78 | 6167.29 |
| 2GB | 11944.33 | 12770.91 | 8645.80 | 7864.18 | 6786.01 | 12793.68 | 13185.73 | 12846.10 | 12244.02 |

E2E server benchmark float16

Server: VLLM_USE_V1=1 VLLM_USE_TRITON_FLASH_ATTN=0 vllm serve meta-llama/Llama-3.1-70B-Instruct --block_size=32 --disable-log-requests --no-enable-prefix-caching -tp $tp --dtype float16
Client: python benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-70B-Instruct --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --num-prompts 500 --request-rate 10 --ignore-eos

TP=8

| | TTFT, ms | Speedup | TPOT, ms | Speedup |
|---|---|---|---|---|
| baseline | 145 | 1 | 51 | 1x |
| QR, fp | 106 | 1.37x | 38 | 1.34x |
| QR, int4 | 90 | 1.61x | 35 | 1.46x |

TP=4

| | TTFT, ms | Speedup | TPOT, ms | Speedup |
|---|---|---|---|---|
| baseline | 316 | 1 | 89 | 1x |
| QR, fp | 280 | 1.09x | 89 | 1x |
| QR, int8 | 270 | 1.17x | 85 | 1.05x |
| QR, int6 | 222 | 1.42x | 70 | 1.27x |
| QR, int4 | 138 | 2.2x | 45 | 2x |

E2E server benchmark bfloat16

Server: VLLM_USE_V1=1 VLLM_USE_TRITON_FLASH_ATTN=0 vllm serve model_path --block_size=32 --disable-log-requests --no-enable-prefix-caching -tp $tp --dtype auto
Client: python benchmarks/benchmark_serving.py --model model_path --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --num-prompts 500 --request-rate 10 --ignore-eos

All bf16 runs below use VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 (bf16 all-reduce inputs cast to fp16); a combined launch sketch follows.
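
Combined with the server command above, a single run from this sweep might have been launched roughly as follows (model_path and $tp are placeholders as above; this is a sketch, not the exact script used):

```bash
# Sketch: bf16 serving with the bf16 -> fp16 cast enabled for quick allreduce.
export VLLM_ROCM_QUICK_REDUCE_QUANTIZATION=INT8   # swap for INT6 / INT4 per row
export VLLM_ROCM_QR_CAST_BF16_TO_FP16=1
VLLM_USE_V1=1 VLLM_USE_TRITON_FLASH_ATTN=0 vllm serve model_path \
    --block_size=32 --disable-log-requests --no-enable-prefix-caching \
    -tp $tp --dtype auto
```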

| Model | TP | Q_level | TTFT, ms | Speedup | TPOT, ms | Speedup |
|---|---|---|---|---|---|---|
| Qwen2.5-72B | tp8 | baseline | 131.03 | 1 | 45.21 | 1 |
| | | Q8 | 127.28 | 1.029462602 | 44.13 | 1.024473148 |
| | | Q6 | 113.31 | 1.156385138 | 42.75 | 1.05754386 |
| | | Q4 | 108.38 | 1.208986898 | 38.03 | 1.188798317 |
| | tp4 | baseline | 294.69 | 1 | 90.12 | 1 |
| | | Q8 | 272.48 | 1.08151057 | 85.11 | 1.058864998 |
| | | Q6 | 235.86 | 1.249427627 | 74.72 | 1.206102784 |
| | | Q4 | 204.9 | 1.438213763 | 64.23 | 1.403082672 |
| | tp2 | baseline | 4484.25 | 1 | 310.35 | 1 |
| | | Q8 | 1728.56 | 2.594211367 | 277.54 | 1.118217194 |
| | | Q6 | 1372.67 | 3.266808483 | 265.78 | 1.167695086 |
| | | Q4 | 1083.78 | 4.137601727 | 247.3 | 1.254953498 |
| Llama-3.1-70B | tp8 | baseline | 132.17 | 1 | 46.81 | 1 |
| | | Q8 | 121.23 | 1.090241689 | 45.1 | 1.037915743 |
| | | Q6 | 109.11 | 1.211346348 | 43.2 | 1.083564815 |
| | | Q4 | 103.39 | 1.278363478 | 41.54 | 1.126865672 |
| | tp4 | baseline | 264.74 | 1 | 81.82 | 1 |
| | | Q8 | 248.99 | 1.063255552 | 77.74 | 1.052482634 |
| | | Q6 | 215.11 | 1.230719167 | 68.14 | 1.200763135 |
| | | Q4 | 189.63 | 1.396087117 | 59.18 | 1.382561676 |
| | tp2 | baseline | 3534.81 | 1 | 303.54 | 1 |
| | | Q8 | 1422.79 | 2.484421454 | 269.87 | 1.124763775 |
| | | Q6 | 1176.12 | 3.005484134 | 257.88 | 1.177059097 |
| | | Q4 | 954.81 | 3.702108273 | 238.16 | 1.27452133 |
| Llama-3.1-8B | tp8 | baseline | 29.75 | 1 | 8.72 | 1 |
| | | Q8 | 27.56 | 1.07946299 | 8.56 | 1.018691589 |
| | | Q6 | 25.65 | 1.159844055 | 8.48 | 1.028301887 |
| | | Q4 | 25.46 | 1.168499607 | 8.42 | 1.035629454 |
| | tp4 | baseline | 33.48 | 1 | 9.93 | 1 |
| | | Q8 | 32.64 | 1.025735294 | 9.71 | 1.022657055 |
| | | Q6 | 27.04 | 1.23816568 | 9.24 | 1.074675325 |
| | | Q4 | 26.77 | 1.250653717 | 9.25 | 1.073513514 |
| | tp2 | baseline | 38.88 | 1 | 11.11 | 1 |
| | | Q8 | 35.56 | 1.09336333 | 10.22 | 1.087084149 |
| | | Q6 | 34.08 | 1.14084507 | 9.99 | 1.112112112 |
| | | Q4 | 33.2 | 1.171084337 | 9.93 | 1.118831823 |

Evaluation results
On the MMLU benchmark (Llama 3.1 70B, TP=8)

| | MMLU, STEM | MMLU, human | MMLU, social |
|---|---|---|---|
| baseline | 0.76 | 0.81 | 0.88 |
| QR, int4 | 0.76 | 0.81 | 0.88 |

On GSM8K (with the bf16 to fp16 cast enabled via VLLM_ROCM_QR_CAST_BF16_TO_FP16=1)

| Model | QR quant level | TP | GSM8K |
|---|---|---|---|
| Meta-Llama-3.1-70B-Instruct | NONE (baseline) | 2 | 0.94 |
| | INT8 | 2 | 0.932 |
| | INT6 | 2 | 0.956 |
| | INT4 | 2 | 0.956 |
| | NONE (baseline) | 4 | 0.94 |
| | INT8 | 4 | 0.932 |
| | INT6 | 4 | 0.936 |
| | INT4 | 4 | 0.916 |
| | NONE (baseline) | 8 | 0.944 |
| | INT8 | 8 | 0.944 |
| | INT6 | 8 | 0.94 |
| | INT4 | 8 | 0.924 |
| Meta-Llama-3.1-8B-Instruct | NONE (baseline) | 2 | 0.784 |
| | INT8 | 2 | 0.8 |
| | INT6 | 2 | 0.776 |
| | INT4 | 2 | 0.8 |
| | NONE (baseline) | 4 | 0.764 |
| | INT8 | 4 | 0.8 |
| | INT6 | 4 | 0.776 |
| | INT4 | 4 | 0.788 |
| | NONE (baseline) | 8 | 0.8 |
| | INT8 | 8 | 0.796 |
| | INT6 | 8 | 0.78 |
| | INT4 | 8 | 0.8 |
| Qwen2.5-72B-Instruct | NONE (baseline) | 2 | 0.904 |
| | INT8 | 2 | 0.908 |
| | INT6 | 2 | 0.924 |
| | INT4 | 2 | 0.896 |
| | NONE (baseline) | 4 | 0.916 |
| | INT8 | 4 | 0.92 |
| | INT6 | 4 | 0.908 |
| | INT4 | 4 | 0.88 |
| | NONE (baseline) | 8 | 0.924 |
| | INT8 | 8 | 0.92 |
| | INT6 | 8 | 0.904 |
| | INT4 | 8 | 0.9 |

The initial PR was proposed by @ilmarkov, and we collaborated on it.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify

mergify bot commented Jun 17, 2025

⚠️ The sha of the head commit of this PR conflicts with #16804. Mergify cannot evaluate rules on this PR. ⚠️

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @lihaoyang-amd, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances VLLM's distributed communication capabilities on AMD GPUs by introducing a highly optimized 'Quick All-Reduce' feature. It integrates a specialized quickreduce library, providing support for various quantization levels to accelerate all-reduce operations. The changes include adding low-level C++/HIP kernels, exposing them through Python bindings, and implementing intelligent dispatch logic to automatically select the most performant communication strategy based on tensor properties and system configuration, aiming to improve performance for large tensor all-reduces.

Highlights

  • New Feature: Quick All-Reduce: Introduced a 'Quick All-Reduce' feature specifically optimized for AMD GPUs (ROCm), leveraging the quickreduce library for enhanced distributed communication performance.
  • Quantization Support: The new quick all-reduce supports various quantization levels, including FP16, INT8, INT6, and INT4, allowing for flexible performance tuning based on precision requirements.
  • Intelligent Dispatch Logic: Integrated the quick all-reduce into VLLM's existing custom all-reduce framework, implementing logic to dynamically select between the standard custom all-reduce and the new quick all-reduce based on message size, data type, and the distributed world size.
  • Low-Level Kernel Implementation: Added new C++/HIP kernels and Python bindings for the quick all-reduce operations, including specialized code for efficient IPC memory handling and GPU-specific intrinsics for AMD's CDNA architecture.
  • Configurable Behavior: Introduced new environment variables (VLLM_ROCM_QR_QUANT_REGIME and VLLM_ROCM_QR_CAST_BF16_TO_FP16) to allow users to configure the quick all-reduce's quantization level and bfloat16 casting behavior.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a 'Quick All Reduce' feature, primarily targeting ROCm/HIP environments. It adds new C++/HIP kernels for an accelerated all-reduce operation, including support for quantization (FP16, INT8, INT6, INT4). The changes include new source files for the QuickReduce implementation, CMake build system updates, PyTorch C++ bindings, Python wrappers, and integration into the CustomAllreduce class. Key areas of feedback include a critical bug in a quantization codec, a potential runtime error in Python logic, and a suggestion for type clarity in C++ header declarations.

@mergify

mergify bot commented Jun 17, 2025

⚠️ The sha of the head commit of this PR conflicts with #16804. Mergify cannot evaluate rules on this PR. ⚠️

@mergify mergify bot added the ci/build label Jun 17, 2025
@mergify
Copy link

mergify bot commented Jun 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lihaoyang-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 17, 2025
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from e9fff8c to f314fe4 Compare June 17, 2025 16:48
@mergify

mergify bot commented Jun 17, 2025

⚠️ The sha of the head commit of this PR conflicts with #16804. Mergify cannot evaluate rules on this PR. ⚠️

@lihaoyang-amd lihaoyang-amd marked this pull request as ready for review June 18, 2025 09:42
@lihaoyang-amd lihaoyang-amd changed the title [feature] add quick all reduce [Feature] add quick all reduce Jun 18, 2025
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from f314fe4 to 381009f Compare June 18, 2025 12:19
@mergify mergify bot removed the needs-rebase label Jun 18, 2025
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from 381009f to d5832a8 Compare June 19, 2025 04:25
@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from 3b81d4d to 811be44 Compare June 19, 2025 07:55
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 20, 2025
@DarkLight1337
Member

Can you merge from main to see if the CI failures are resolved?

@mergify

mergify bot commented Jun 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lihaoyang-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 20, 2025
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from aa4f696 to 51b9e6c Compare June 20, 2025 15:00
@mergify mergify bot removed the needs-rebase label Jun 20, 2025
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from 2f23c69 to 0209b2e Compare June 23, 2025 03:52
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch 2 times, most recently from 1552e5a to c8c63dd Compare June 24, 2025 11:12
@lihaoyang-amd lihaoyang-amd force-pushed the lhy/add_quick_all_reduce branch from e7e84da to 7b7822c Compare June 24, 2025 16:48
lihaoyang-amd and others added 6 commits June 26, 2025 00:34
Co-authored-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Member

@tlrmchlsmth tlrmchlsmth left a comment


Overall looks good and great work on this PR! I left a few more comments, but I think they should be easy to address, thx!

Co-authored-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Member

@tlrmchlsmth tlrmchlsmth left a comment


LGTM!

@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) June 27, 2025 01:16
@vllm-bot vllm-bot merged commit 0740e29 into vllm-project:main Jun 27, 2025
96 of 100 checks passed
@lihaoyang-amd lihaoyang-amd changed the title [Feature] add quick all reduce [Feature][Rocm] add quick all reduce for rocm Jun 27, 2025
@mergify mergify bot added the rocm Related to AMD ROCm label Jun 27, 2025
@fxmarty-amd
Contributor

@lihaoyang-amd do you know why

[2025-06-27T07:36:58Z] =========================== short test summary info ============================
[2025-06-27T07:36:58Z] FAILED quantization/test_fp8.py::test_scaled_fp8_quant[dtype0] - AssertionError
[2025-06-27T07:36:58Z] FAILED quantization/test_fp8.py::test_scaled_fp8_quant[dtype1] - AssertionError

are failing? Was it already failing?

@lihaoyang-amd
Contributor Author

@lihaoyang-amd do you know why

[2025-06-27T07:36:58Z] =========================== short test summary info ============================
[2025-06-27T07:36:58Z] FAILED quantization/test_fp8.py::test_scaled_fp8_quant[dtype0] - AssertionError
[2025-06-27T07:36:58Z] FAILED quantization/test_fp8.py::test_scaled_fp8_quant[dtype1] - AssertionError

are failing? Was it already failing?

Yes, it was already failing; you can refer to PRs that were merged before this one in the commit history.

@fxmarty-amd
Contributor

@lihaoyang-amd Thanks. As I found out in #17888 (comment), it had indeed been failing for at least a few commits/days earlier (see #20045 and buildkite.com/vllm/ci/builds/22676#0197a8d1-7fc7-4a19-bb6d-c1664f589dc9).


Labels

ci/build · qwen (Related to Qwen models) · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm)

7 participants