@@ -65,8 +65,6 @@ ARG PYTORCH_BRANCH
6565ARG PYTORCH_VISION_BRANCH
6666ARG PYTORCH_REPO
6767ARG PYTORCH_VISION_REPO
68- ARG FA_BRANCH
69- ARG FA_REPO
7068RUN git clone ${PYTORCH_REPO} pytorch
7169RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
7270 pip install -r requirements.txt && git submodule update --init --recursive \
@@ -77,14 +75,20 @@ RUN git clone ${PYTORCH_VISION_REPO} vision
7775RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
7876 && python3 setup.py bdist_wheel --dist-dir=dist \
7977 && pip install dist/*.whl
78+ RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
79+ && cp /app/vision/dist/*.whl /app/install
80+
81+ FROM base AS build_fa
82+ ARG FA_BRANCH
83+ ARG FA_REPO
84+ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
85+ pip install /install/*.whl
8086RUN git clone ${FA_REPO}
8187RUN cd flash-attention \
8288 && git checkout ${FA_BRANCH} \
8389 && git submodule update --init \
8490 && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
85- RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
86- && cp /app/vision/dist/*.whl /app/install \
87- && cp /app/flash-attention/dist/*.whl /app/install
91+ RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install
8892
8993FROM base AS build_aiter
9094ARG AITER_BRANCH
@@ -103,6 +107,8 @@ FROM base AS debs
103107RUN mkdir /app/debs
104108RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
105109 cp /install/*.whl /app/debs
110+ RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \
111+ cp /install/*.whl /app/debs
106112RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
107113 cp /install/*.whl /app/debs
108114RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
@@ -111,13 +117,7 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
111117 cp /install/*.whl /app/debs
112118
113119FROM base AS final
114- RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
115- pip install /install/*.whl
116- RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
117- pip install /install/*.whl
118- RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
119- pip install /install/*.whl
120- RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
120+ RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
121121 pip install /install/*.whl
122122
123123ARG BASE_IMAGE
0 commit comments