Skip to content

Commit

Permalink
Merge pull request vllm-project#1 from gnovack/aoyu/neuron-v1
Browse files Browse the repository at this point in the history
rebase to vllm main branch with nki_flash_attn
  • Loading branch information
gnovack authored Jan 28, 2025
2 parents 23afd2e + 14367f0 commit a330e7f
Show file tree
Hide file tree
Showing 158 changed files with 5,395 additions and 1,656 deletions.
2 changes: 1 addition & 1 deletion .buildkite/run-neuron-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
--name "${container_name}" \
${image_name} \
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py"
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"
11 changes: 10 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,16 @@ steps:
- vllm/
- tests/v1
commands:
- VLLM_USE_V1=1 pytest -v -s v1
# split the test to avoid interference
- VLLM_USE_V1=1 pytest -v -s v1/core
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
- VLLM_USE_V1=1 pytest -v -s v1/e2e

- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
.ipynb

# IPython
profile_default/
Expand Down
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ default_stages:
- manual # Run in CI
repos:
- repo: https://github.com/google/yapf
rev: v0.32.0
rev: v0.43.0
hooks:
- id: yapf
args: [--in-place, --verbose]
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.5
rev: v0.9.3
hooks:
- id: ruff
args: [--output-format, github]
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.0
hooks:
- id: codespell
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
Expand All @@ -23,7 +23,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.5
rev: v19.1.7
hooks:
- id: clang-format
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
Expand All @@ -35,7 +35,7 @@ repos:
- id: pymarkdown
files: docs/.*
- repo: https://github.com/rhysd/actionlint
rev: v1.7.6
rev: v1.7.7
hooks:
- id: actionlint
- repo: local
Expand Down
14 changes: 9 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
Expand All @@ -296,8 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
Expand Down Expand Up @@ -351,7 +351,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
Expand Down Expand Up @@ -446,6 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

message(STATUS "Enabling C extension.")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_C_LIBS cuda)
endif()
define_gpu_extension_target(
_C
DESTINATION vllm
Expand All @@ -454,6 +457,7 @@ define_gpu_extension_target(
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
LIBRARIES ${VLLM_C_LIBS}
USE_SABI 3
WITH_SOABI)

Expand Down Expand Up @@ -576,7 +580,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
23 changes: 21 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \

#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
# TODO: Restore to base image after FlashInfer AOT wheel fixed
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
Expand Down Expand Up @@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

# How to build this FlashInfer wheel:
# $ export FLASHINFER_ENABLE_AOT=1
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
# $ cd flashinfer
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose

RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
fi
COPY examples examples

# Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt

#################### vLLM installation IMAGE ####################

#################### TEST IMAGE ####################
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.tpu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20250122"
ARG NIGHTLY_DATE="20250124"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,8 @@ def main(args: argparse.Namespace):
)

# Traffic
result_json["request_rate"] = (
args.request_rate if args.request_rate < float("inf") else "inf")
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency

Expand Down
43 changes: 28 additions & 15 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ endmacro()
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
# 9.0a to the result.
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Example:
Expand All @@ -270,34 +270,47 @@ endmacro()
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})

# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
set(_CUDA_ARCHS "9.0a")
endif()
endif()

list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
# less or eqault to ARCH
foreach(_ARCH ${CUDA_ARCHS})
set(_TMP_ARCH)
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
set(_TMP_ARCH ${_SRC_ARCH})
else()
break()
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version).
foreach(_ARCH ${TGT_CUDA_ARCHS_})
set(_TMP_ARCH)
# Extract the major version of the target arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
# Extract the major version of the source arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
# Check major-version match AND version-less-or-equal
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}")
endif()
else()
# If we hit a version greater than the target, we can break
break()
endif()
endforeach()

# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
if (_TMP_ARCH)
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
endif()
endforeach()
if (_TMP_ARCH)
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
endif()
endforeach()

list(REMOVE_DUPLICATES _CUDA_ARCHS)
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
Expand Down
8 changes: 6 additions & 2 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ struct Signal {
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};

struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};

struct __align__(16) RankSignals { Signal* signals[8]; };
struct __align__(16) RankSignals {
Signal* signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
Expand Down
8 changes: 4 additions & 4 deletions csrc/moe/marlin_kernels/marlin_moe_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ __device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
Expand Down Expand Up @@ -182,8 +182,8 @@ __device__ inline FragB dequant<vllm::kU4.id()>(int q) {
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);

const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
Expand Down
4 changes: 3 additions & 1 deletion csrc/moe/moe_align_sum_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,

extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
token_cnts_t* tokens_cnts =
(token_cnts_t*)(shared_mem + num_experts +
1); // 2d tensor with shape (blockDim.x + 1, num_experts)

for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
Expand Down
16 changes: 8 additions & 8 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
Expand All @@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {

// Guarantee that the `(a & b) | c` operations are LOP3s.

int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);

typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
Expand All @@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);

const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
Expand All @@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {

// Guarantee that the `(a & b) | c` operations are LOP3s.

int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);

typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
static constexpr uint32_t HI = 0x00f000f0;
static constexpr uint32_t EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
uint32_t t0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
uint32_t t1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
static constexpr uint32_t SUB = 0x64086408;
Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/marlin/sparse/common/mma.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ __device__ inline FragB dequant_4bit(int q) {
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
Expand Down
4 changes: 3 additions & 1 deletion csrc/rocm/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions){UNREACHABLE_CODE}
const int max_num_partitions) {
UNREACHABLE_CODE
}

#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

Expand Down
Loading

0 comments on commit a330e7f

Please sign in to comment.