Skip to content

Commit 566d676

Browse files
FA2 8.0 PTX (#69)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 099599c commit 566d676

File tree

2 files changed

+83
-26
lines changed

2 files changed

+83
-26
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ if (FA2_ENABLED)
135135

136136
# For CUDA we set the architectures on a per file basis
137137
if (VLLM_GPU_LANG STREQUAL "CUDA")
138-
cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
138+
cuda_archs_loose_intersection(FA2_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
139139
message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")
140140

141141
set_gencode_flags_for_srcs(

cmake/utils.cmake

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
6262
#
6363
set(SRCS ${ORIG_SRCS})
6464
set(CXX_SRCS ${ORIG_SRCS})
65-
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
66-
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
65+
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
66+
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
6767

6868
#
6969
# Generate ROCm/HIP source file names from CUDA file names.
@@ -80,7 +80,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
8080
set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
8181
add_custom_target(
8282
hipify${NAME}
83-
COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
83+
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
8484
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
8585
BYPRODUCTS ${HIP_SRCS}
8686
COMMENT "Running hipify on ${NAME} extension source files.")
@@ -232,11 +232,26 @@ macro(set_gencode_flags_for_srcs)
232232
"${multiValueArgs}" ${ARGN} )
233233

234234
foreach(_ARCH ${arg_CUDA_ARCHS})
235-
string(REPLACE "." "" _ARCH "${_ARCH}")
236-
set_gencode_flag_for_srcs(
237-
SRCS ${arg_SRCS}
238-
ARCH "compute_${_ARCH}"
239-
CODE "sm_${_ARCH}")
235+
# handle +PTX suffix: generate both sm and ptx codes if requested
236+
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
237+
if(NOT _HAS_PTX EQUAL -1)
238+
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
239+
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
240+
set_gencode_flag_for_srcs(
241+
SRCS ${arg_SRCS}
242+
ARCH "compute_${_STRIPPED_ARCH}"
243+
CODE "sm_${_STRIPPED_ARCH}")
244+
set_gencode_flag_for_srcs(
245+
SRCS ${arg_SRCS}
246+
ARCH "compute_${_STRIPPED_ARCH}"
247+
CODE "compute_${_STRIPPED_ARCH}")
248+
else()
249+
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
250+
set_gencode_flag_for_srcs(
251+
SRCS ${arg_SRCS}
252+
ARCH "compute_${_STRIPPED_ARCH}"
253+
CODE "sm_${_STRIPPED_ARCH}")
254+
endif()
240255
endforeach()
241256

242257
if (${arg_BUILD_PTX_FOR_ARCH})
@@ -255,15 +270,18 @@ endmacro()
255270
#
256271
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
257272
# `<major>.<minor>[letter]` compute the "loose intersection" with the
258-
# `TGT_CUDA_ARCHS` list of gencodes.
273+
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
274+
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
275+
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
276+
# architecture in `SRC_CUDA_ARCHS`.
259277
# The loose intersection is defined as:
260278
# { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
261279
# where `<=` is the version comparison operator.
262280
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
263281
# in `SRC_CUDA_ARCHS` that is less than or equal to the version in `TGT_CUDA_ARCHS`.
264-
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
265-
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
266-
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
282+
# We have special handling for x.0a (currently 9.0a and 10.0a): if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
283+
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
284+
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
267285
# The result is stored in `OUT_CUDA_ARCHS`.
268286
#
269287
# Example:
@@ -272,36 +290,63 @@ endmacro()
272290
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
273291
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
274292
#
293+
# Example With PTX:
294+
# SRC_CUDA_ARCHS="8.0+PTX"
295+
# TGT_CUDA_ARCHS="9.0"
296+
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
297+
# OUT_CUDA_ARCHS="8.0+PTX"
298+
#
275299
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
276-
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
277-
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
300+
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
301+
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
302+
303+
# handle +PTX suffix: separate base arch for matching, record PTX requests
304+
set(_PTX_ARCHS)
305+
foreach(_arch ${_SRC_CUDA_ARCHS})
306+
if(_arch MATCHES "\\+PTX$")
307+
string(REPLACE "+PTX" "" _base "${_arch}")
308+
list(APPEND _PTX_ARCHS "${_base}")
309+
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
310+
list(APPEND _SRC_CUDA_ARCHS "${_base}")
311+
endif()
312+
endforeach()
313+
list(REMOVE_DUPLICATES _PTX_ARCHS)
314+
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
278315

279-
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
280-
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
316+
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
317+
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
281318
set(_CUDA_ARCHS)
282-
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
283-
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
284-
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
285-
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
319+
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
320+
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
321+
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
322+
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
286323
set(_CUDA_ARCHS "9.0a")
287324
endif()
288325
endif()
289326

290-
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
327+
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
328+
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
329+
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
330+
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
331+
set(_CUDA_ARCHS "10.0a")
332+
endif()
333+
endif()
334+
335+
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
291336

292337
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
293338
# is less or equal to ARCH (but has the same major version since SASS binary
294339
# compatibility is only forward compatible within the same major version; archs listed with `+PTX` may also match across major versions).
295-
foreach(_ARCH ${TGT_CUDA_ARCHS_})
340+
foreach(_ARCH ${_TGT_CUDA_ARCHS})
296341
set(_TMP_ARCH)
297342
# Extract the major version of the target arch
298343
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
299-
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
344+
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
300345
# Extract the major version of the source arch
301346
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
302-
# Check major-version match AND version-less-or-equal
347+
# Check version-less-or-equal, and allow PTX arches to match across majors
303348
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
304-
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
349+
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
305350
set(_TMP_ARCH "${_SRC_ARCH}")
306351
endif()
307352
else()
@@ -317,6 +362,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
317362
endforeach()
318363

319364
list(REMOVE_DUPLICATES _CUDA_ARCHS)
365+
366+
# reapply +PTX suffix to architectures that requested PTX
367+
set(_FINAL_ARCHS)
368+
foreach(_arch ${_CUDA_ARCHS})
369+
if(_arch IN_LIST _PTX_ARCHS)
370+
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
371+
else()
372+
list(APPEND _FINAL_ARCHS "${_arch}")
373+
endif()
374+
endforeach()
375+
set(_CUDA_ARCHS ${_FINAL_ARCHS})
376+
320377
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
321378
endfunction()
322379

0 commit comments

Comments
 (0)