Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ find_package(pybind11 REQUIRED)

file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
execute_process(
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -gencode arch=compute_120,code=sm_120 -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
RESULT_VARIABLE NVCC_RESULT
OUTPUT_VARIABLE NVCC_OUTPUT
ERROR_VARIABLE NVCC_ERROR_OUTPUT
Expand All @@ -27,8 +27,11 @@ else()
endif()

if (NVCC_SUPPORTS_SM90)
set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE)
set(TORCH_CUDA_ARCH_LIST "8.6;9.0a" CACHE STRING "Add arch tag 90a to NVCC" FORCE) # TODO: Check if 9.0a is correct for sm_90a, it might be just 9.0
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
# Add Blackwell support if NVCC supports it (determined by the test_cuda.cu compilation)
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_120,code=sm_120")
set(TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST};12.0" CACHE STRING "Add arch tag 120 to NVCC" FORCE) # TODO: Check if 12.0 is correct for sm_120
endif()
find_package(Torch REQUIRED)

Expand Down
3 changes: 2 additions & 1 deletion deep_gemm/jit/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def flags(cls) -> List[str]:
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a',
'-gencode=arch=compute_120,code=sm_120',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}']

Expand Down Expand Up @@ -230,7 +231,7 @@ def include_dirs() -> List[str]:
@classmethod
def flags(cls) -> List[str]:
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
'--gpu-architecture=compute_90a', '--gpu-architecture=compute_120', '-default-device']
# NOTES: PCH is vital for compilation speed
if cls.__version__() >= (12, 8):
flags += ['--pch']
Expand Down