Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions docker/Dockerfile.rocm_base
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
Expand All @@ -77,14 +75,20 @@ RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install

FROM base AS build_fa
ARG FA_BRANCH
ARG FA_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN git clone ${FA_REPO}
RUN cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install

FROM base AS build_aiter
ARG AITER_BRANCH
Expand All @@ -103,6 +107,8 @@ FROM base AS debs
RUN mkdir /app/debs
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
Expand All @@ -111,13 +117,7 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
cp /install/*.whl /app/debs

FROM base AS final
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
pip install /install/*.whl

ARG BASE_IMAGE
Expand Down