@@ -62,8 +62,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
6262 #
6363 set (SRCS ${ORIG_SRCS} )
6464 set (CXX_SRCS ${ORIG_SRCS} )
65- list (FILTER SRCS EXCLUDE REGEX "\. (cc)|(cpp)$" )
66- list (FILTER CXX_SRCS INCLUDE REGEX "\. (cc)|(cpp)$" )
65+ list (FILTER SRCS EXCLUDE REGEX "\. (cc)|(cpp)|(hip) $" )
66+ list (FILTER CXX_SRCS INCLUDE REGEX "\. (cc)|(cpp)|(hip) $" )
6767
6868 #
6969 # Generate ROCm/HIP source file names from CUDA file names.
@@ -80,7 +80,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
8080 set (CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR} /csrc)
8181 add_custom_target (
8282 hipify${NAME}
83- COMMAND ${CMAKE_SOURCE_DIR} /cmake/hipify.py -p ${CMAKE_SOURCE_DIR} /csrc -o ${CSRC_BUILD_DIR} ${SRCS}
83+ COMMAND ${Python_EXECUTABLE} ${ CMAKE_SOURCE_DIR} /cmake/hipify.py -p ${CMAKE_SOURCE_DIR} /csrc -o ${CSRC_BUILD_DIR} ${SRCS}
8484 DEPENDS ${CMAKE_SOURCE_DIR} /cmake/hipify.py ${SRCS}
8585 BYPRODUCTS ${HIP_SRCS}
8686 COMMENT "Running hipify on ${NAME} extension source files." )
@@ -232,11 +232,26 @@ macro(set_gencode_flags_for_srcs)
232232 "${multiValueArgs} " ${ARGN} )
233233
234234 foreach (_ARCH ${arg_CUDA_ARCHS} )
235- string (REPLACE "." "" _ARCH "${_ARCH} " )
236- set_gencode_flag_for_srcs(
237- SRCS ${arg_SRCS}
238- ARCH "compute_${_ARCH} "
239- CODE "sm_${_ARCH} " )
235+ # handle +PTX suffix: generate both sm and ptx codes if requested
236+ string (FIND "${_ARCH} " "+PTX" _HAS_PTX)
237+ if (NOT _HAS_PTX EQUAL -1)
238+ string (REPLACE "+PTX" "" _BASE_ARCH "${_ARCH} " )
239+ string (REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH} " )
240+ set_gencode_flag_for_srcs(
241+ SRCS ${arg_SRCS}
242+ ARCH "compute_${_STRIPPED_ARCH} "
243+ CODE "sm_${_STRIPPED_ARCH} " )
244+ set_gencode_flag_for_srcs(
245+ SRCS ${arg_SRCS}
246+ ARCH "compute_${_STRIPPED_ARCH} "
247+ CODE "compute_${_STRIPPED_ARCH} " )
248+ else ()
249+ string (REPLACE "." "" _STRIPPED_ARCH "${_ARCH} " )
250+ set_gencode_flag_for_srcs(
251+ SRCS ${arg_SRCS}
252+ ARCH "compute_${_STRIPPED_ARCH} "
253+ CODE "sm_${_STRIPPED_ARCH} " )
254+ endif ()
240255 endforeach ()
241256
242257 if (${arg_BUILD_PTX_FOR_ARCH} )
@@ -255,15 +270,18 @@ endmacro()
255270#
256271# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
257272# `<major>.<minor>[letter]` compute the "loose intersection" with the
258- # `TGT_CUDA_ARCHS` list of gencodes.
273+ # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
274+ # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
275+ # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
276+ # architecture in `SRC_CUDA_ARCHS`.
259277# The loose intersection is defined as:
260278# { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
261279# where `<=` is the version comparison operator.
262280# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
263281# in `SRC_CUDA_ARCHS` that is less than or equal to the version in `TGT_CUDA_ARCHS`.
264- # We have special handling for 9 .0a, if 9 .0a is in `SRC_CUDA_ARCHS` and 9 .0 is
265- # in `TGT_CUDA_ARCHS` then we should remove 9 .0a from `SRC_CUDA_ARCHS` and add
266- # 9 .0a to the result (and remove 9 .0 from TGT_CUDA_ARCHS).
282+ # We have special handling for x.0a: if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
283+ # in `TGT_CUDA_ARCHS`, then we remove x.0a from `SRC_CUDA_ARCHS` and add
284+ # x.0a to the result (and remove x.0 from `TGT_CUDA_ARCHS`).
267285# The result is stored in `OUT_CUDA_ARCHS`.
268286#
269287# Example:
@@ -272,36 +290,63 @@ endmacro()
272290# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
273291# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
274292#
293+ # Example With PTX:
294+ # SRC_CUDA_ARCHS="8.0+PTX"
295+ # TGT_CUDA_ARCHS="9.0"
296+ # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
297+ # OUT_CUDA_ARCHS="8.0+PTX"
298+ #
275299function (cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
276- list (REMOVE_DUPLICATES SRC_CUDA_ARCHS)
277- set (TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS} )
300+ set (_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS} " )
301+ set (_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS} )
302+
303+ # handle +PTX suffix: separate base arch for matching, record PTX requests
304+ set (_PTX_ARCHS)
305+ foreach (_arch ${_SRC_CUDA_ARCHS} )
306+ if (_arch MATCHES "\\ +PTX$" )
307+ string (REPLACE "+PTX" "" _base "${_arch} " )
308+ list (APPEND _PTX_ARCHS "${_base} " )
309+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch} " )
310+ list (APPEND _SRC_CUDA_ARCHS "${_base} " )
311+ endif ()
312+ endforeach ()
313+ list (REMOVE_DUPLICATES _PTX_ARCHS)
314+ list (REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
278315
279- # if 9 .0a is in SRC_CUDA_ARCHS and 9 .0 is in CUDA_ARCHS then we should
280- # remove 9 .0a from SRC_CUDA_ARCHS and add 9 .0a to _CUDA_ARCHS
316+ # if x.0a is in SRC_CUDA_ARCHS and x.0 is in TGT_CUDA_ARCHS then we should
317+ # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
281318 set (_CUDA_ARCHS)
282- if ("9.0a" IN_LIST SRC_CUDA_ARCHS )
283- list (REMOVE_ITEM SRC_CUDA_ARCHS "9.0a" )
284- if ("9.0" IN_LIST TGT_CUDA_ARCHS_ )
285- list (REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0" )
319+ if ("9.0a" IN_LIST _SRC_CUDA_ARCHS )
320+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a" )
321+ if ("9.0" IN_LIST TGT_CUDA_ARCHS )
322+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "9.0" )
286323 set (_CUDA_ARCHS "9.0a" )
287324 endif ()
288325 endif ()
289326
290- list (SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
327+ if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
328+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a" )
329+ if ("10.0" IN_LIST TGT_CUDA_ARCHS)
330+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "10.0" )
331+ set (_CUDA_ARCHS "10.0a" )
332+ endif ()
333+ endif ()
334+
335+ list (SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
291336
292337 # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
293338 # is less or equal to ARCH (but has the same major version since SASS binary
294339 # compatibility is only forward compatible within the same major version).
295- foreach (_ARCH ${TGT_CUDA_ARCHS_ } )
340+ foreach (_ARCH ${_TGT_CUDA_ARCHS } )
296341 set (_TMP_ARCH)
297342 # Extract the major version of the target arch
298343 string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" TGT_ARCH_MAJOR "${_ARCH} " )
299- foreach (_SRC_ARCH ${SRC_CUDA_ARCHS } )
344+ foreach (_SRC_ARCH ${_SRC_CUDA_ARCHS } )
300345 # Extract the major version of the source arch
301346 string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" SRC_ARCH_MAJOR "${_SRC_ARCH} " )
302- # Check major- version match AND version -less-or-equal
347+ # Check version-less-or-equal, and allow PTX arches to match across majors
303348 if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
304- if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
349+ if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
305350 set (_TMP_ARCH "${_SRC_ARCH} " )
306351 endif ()
307352 else ()
@@ -317,6 +362,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
317362 endforeach ()
318363
319364 list (REMOVE_DUPLICATES _CUDA_ARCHS)
365+
366+ # reapply +PTX suffix to architectures that requested PTX
367+ set (_FINAL_ARCHS)
368+ foreach (_arch ${_CUDA_ARCHS} )
369+ if (_arch IN_LIST _PTX_ARCHS)
370+ list (APPEND _FINAL_ARCHS "${_arch} +PTX" )
371+ else ()
372+ list (APPEND _FINAL_ARCHS "${_arch} " )
373+ endif ()
374+ endforeach ()
375+ set (_CUDA_ARCHS ${_FINAL_ARCHS} )
376+
320377 set (${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
321378endfunction ()
322379
0 commit comments