Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions docker/Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
# --------------
# DLL: TEMP since aiter is volatile. When base is locked still build aiter
FROM base AS build_aiter

# Rebuild LLVM before AITER
ENV FLATMM_HIP_CLANG_PATH=/app/git/llvm-project/build/bin/
RUN mkdir -p /app/git && cd /app/git && git clone https://github.com/jrbyrnes/llvm-project.git \
&& cd llvm-project && git checkout c7e653b3343f7757920e7581a10b2015b98af647 \
&& mkdir build && cd build \
&& cmake -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm \
&& make -j64

ARG AITER_BRANCH
ARG AITER_REPO
# RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
Expand Down Expand Up @@ -117,12 +126,7 @@ RUN cd /vllm-workspace \
# -----------------------
# Final vLLM image
FROM base AS final
ENV FLATMM_HIP_CLANG_PATH=/app/git/llvm-project/build/bin/
RUN mkdir -p /app/git && cd /app/git && git clone https://github.com/jrbyrnes/llvm-project.git \
&& cd llvm-project && git checkout c7e653b3343f7757920e7581a10b2015b98af647 \
&& mkdir build && cd build \
&& cmake -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm \
&& make -j64

RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
Expand Down
15 changes: 10 additions & 5 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def should_use_flashinfer_mxfp4():
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
import aiter
from aiter.fused_moe import fused_topk, moe_sorting
from aiter.ops.shuffle import shuffle_mxfp4_weight, shuffle_mxfp4_scale
from aiter.ops.shuffle import shuffle_weight_a16w4, shuffle_scale_a16w4

class Mxfp4Config(QuantizationConfig):

Expand Down Expand Up @@ -498,11 +498,16 @@ def swap_every_two_rows(x, axis=-1):
e, n, k = w13_aiter_weight.shape
w13_aiter_weight = w13_aiter_weight.view(e, n // 2, 2, k).permute(0, 2, 1, 3).contiguous().view(e, n, k)
w13_aiter_scale = w13_aiter_scale.view(e, n // 2, 2, -1).permute(0, 2, 1, 3).contiguous().view(e, n, -1)

w13_aiter_weight = w13_aiter_weight.view(torch.float4_e2m1fn_x2)
w13_aiter_scale = w13_aiter_scale.view(-1, w13_aiter_scale.shape[-1])
w2_aiter_weight = w2_aiter_weight.view(torch.float4_e2m1fn_x2)
w2_aiter_scale = w2_aiter_scale.view(-1, w2_aiter_scale.shape[-1])

self.w13_weight_aiter_tensor = shuffle_mxfp4_weight(w13_aiter_weight, 16, True)
self.w13_scale_aiter_tensor = shuffle_mxfp4_scale(w13_aiter_scale, True)
self.w2_weight_aiter_tensor = shuffle_mxfp4_weight(w2_aiter_weight, 16, False)
self.w2_scale_aiter_tensor = shuffle_mxfp4_scale(w2_aiter_scale, False)
self.w13_weight_aiter_tensor = shuffle_weight_a16w4(w13_aiter_weight, 16, True)
self.w13_scale_aiter_tensor = shuffle_scale_a16w4(w13_aiter_scale, self.num_experts, True)
self.w2_weight_aiter_tensor = shuffle_weight_a16w4(w2_aiter_weight, 16, False)
self.w2_scale_aiter_tensor = shuffle_scale_a16w4(w2_aiter_scale, self.num_experts, False)
self.w13_bias_aiter_tensor = layer.w13_bias.view(-1, n // 2, 2).permute(0, 2, 1).contiguous().view(-1, n)
else:
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
Expand Down
Loading