Skip to content

Commit 9da4a23

Browse files
RayWang96 and fzyzcjy authored
Add more GPU architectures support (#112)
* Add more GPU architectures support * Update layout.py * Optimize performance, Add SM90 support, Add 1D2D SM100 support * Add fmtlib submodule at commit 553ec11 --------- Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
1 parent 03d0be3 commit 9da4a23

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+5575
-2954
lines changed

.gitmodules

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
[submodule "third-party/cutlass"]
22
path = third-party/cutlass
3-
url = https://github.com/NVIDIA/cutlass.git
3+
url = git@github.com:NVIDIA/cutlass.git
4+
[submodule "third-party/fmt"]
5+
path = third-party/fmt
6+
url = git@github.com:fmtlib/fmt.git

CMakeLists.txt

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,33 @@
# NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT
cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_VERBOSE_MAKEFILE ON)

# Host compiler flags; -Wno-psabi silences GCC ABI-change notes
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")

find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# The main Python API entrance
pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp)
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda)

# Enable kernel code indexing with CMake-based IDEs.
# NOTE(review): `cuda_add_library` is provided by the deprecated FindCUDA
# module, which this file never loads (`find_package(CUDAToolkit)` does not
# define it). Since CUDA is an enabled language of this project, the native
# `add_library` compiles .cu sources directly.
add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu)

README.md

Lines changed: 53 additions & 115 deletions
Large diffs are not rendered by default.

csrc/indexing/main.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// Indexing-only translation unit: pulls in every kernel implementation header
// so that CMake-based IDEs can index the CUDA device library. It is never run
// as a real program — the actual kernels are compiled at runtime via JIT.
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/smxx_layout.cuh>

using namespace deep_gemm;

// Trivial entry point: success, no work performed
int main() {
    return 0;
}

csrc/jit/cache.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include <filesystem>
4+
#include <memory>
5+
#include <unordered_map>
6+
7+
#include "kernel_runtime.hpp"
8+
9+
namespace deep_gemm {
10+
11+
class KernelRuntimeCache {
12+
std::unordered_map<std::filesystem::path, std::shared_ptr<KernelRuntime>> cache;
13+
14+
public:
15+
// TODO: consider cache capacity
16+
KernelRuntimeCache() = default;
17+
18+
std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
19+
// Hit the runtime cache
20+
if (const auto& iterator = cache.find(dir_path); iterator != cache.end())
21+
return iterator->second;
22+
23+
if (KernelRuntime::check_validity(dir_path))
24+
return cache[dir_path] = std::make_shared<KernelRuntime>(dir_path);
25+
return nullptr;
26+
}
27+
};
28+
29+
static auto kernel_runtime_cache = std::make_shared<KernelRuntimeCache>();
30+
31+
} // namespace deep_gemm

csrc/jit/compiler.hpp

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <filesystem>
#include <fstream>
#include <regex>
#include <string>

#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/hash.hpp"
#include "../utils/system.hpp"
#include "cache.hpp"
#include "device_runtime.hpp"

namespace deep_gemm {

// Abstract JIT compiler: derived classes (e.g. NVCC) implement `compile`,
// while this base provides versioning, on-disk caching and atomic file writes.
class Compiler {
    std::string library_version;
    std::filesystem::path library_root_path;

    // Hash the concatenated contents of every `.cuh` under
    // `include/deep_gemm`, so any device-library change invalidates
    // previously-cached kernels.
    std::string get_library_version() const {
        // Recursively walk through all subdirectories and update hash
        std::stringstream ss;
        for (const auto& entry: std::filesystem::recursive_directory_iterator(library_include_path / "deep_gemm")) {
            if (entry.is_regular_file() and entry.path().extension() == ".cuh") {
                std::ifstream file(entry.path(), std::ios::binary);
                std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator<char>());
                ss << content;
            }
        }
        return get_hex_digest(ss.str());
    }

public:
    std::string signature, flags;
    std::filesystem::path library_include_path;
    std::filesystem::path cache_dir_path;

    explicit Compiler(const std::filesystem::path& library_root_path) {
        // Static library paths
        this->library_root_path = library_root_path;
        this->library_include_path = library_root_path / "include";
        this->library_version = get_library_version();

        // Cache settings: default to `$HOME/.deep_gemm`, overridable via env
        cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
        if (const auto& env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
            cache_dir_path = env_cache_dir_path;

        // The compiler flags applied to all derived compilers
        signature = "unknown-compiler";
        std::string ptxas_flags = "--ptxas-options=--register-usage-level=10";
        if (get_env<int>("DG_JIT_PTXAS_VERBOSE", 0))
            ptxas_flags += ",--verbose";
        flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags);
    }

    virtual ~Compiler() = default;

    std::filesystem::path make_tmp_dir() const {
        return make_dirs(cache_dir_path / "tmp");
    }

    std::filesystem::path get_tmp_file_path() const {
        return make_tmp_dir() / get_uuid();
    }

    // Atomically publish `data` at `path`: write a unique temporary file in the
    // same cache tree, then `rename` it into place.
    void put(const std::filesystem::path& path, const std::string& data) const {
        const auto tmp_file_path = get_tmp_file_path();

        // Write into the temporary file.
        // NOTE: the write must live outside the assertion macro — a side effect
        // inside `DG_HOST_ASSERT` would vanish entirely if the macro is ever
        // compiled to a no-op.
        std::ofstream out(tmp_file_path, std::ios::binary);
        out.write(data.data(), static_cast<std::streamsize>(data.size()));
        DG_HOST_ASSERT(out.good());
        out.close();

        // Atomically replace
        std::filesystem::rename(tmp_file_path, path);
    }

    // Compile `code` under a content-addressed cache directory (or reuse a
    // previous build) and return its loaded runtime.
    std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const {
        const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code);
        const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature));

        // Hit the runtime cache
        if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
            return runtime;

        // Create the kernel directory
        make_dirs(dir_path);

        // Compile into a temporary CUBIN
        const auto tmp_cubin_path = get_tmp_file_path();
        compile(code, dir_path, tmp_cubin_path);

        // Replace into the cache directory (the directory already exists above;
        // the original redundantly re-created it here)
        std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin");

        // Put into the runtime cache
        const auto& runtime = kernel_runtime_cache->get(dir_path);
        DG_HOST_ASSERT(runtime != nullptr);
        return runtime;
    }

    // Derived compilers translate `code` (written under `dir_path`) into a
    // CUBIN at `cubin_path`.
    virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0;
};
108+
109+
class NVCCCompiler final: public Compiler {
110+
std::filesystem::path nvcc_path;
111+
112+
std::pair<int, int> get_nvcc_version() const {
113+
DG_HOST_ASSERT(std::filesystem::exists(nvcc_path));
114+
115+
// Call the version command
116+
const auto& command = std::string(nvcc_path) + " --version";
117+
const auto& [return_code, output] = call_external_command(command);
118+
DG_HOST_ASSERT(return_code == 0);
119+
120+
// The version should be at least 12.3, for the best performance with 12.9
121+
int major, minor;
122+
std::smatch match;
123+
DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
124+
std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);
125+
DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
126+
if (major < 12 or (major == 12 and minor < 9))
127+
printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance");
128+
return {major, minor};
129+
}
130+
131+
public:
132+
NVCCCompiler(const std::filesystem::path& library_root_path,
133+
const std::filesystem::path& cuda_home_path_by_torch):
134+
Compiler(library_root_path) {
135+
// Override the compiler signature
136+
nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc";
137+
if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
138+
nvcc_path = env_nvcc_path;
139+
const auto& [nvcc_major, nvcc_minor] = get_nvcc_version();
140+
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
141+
142+
// The override the compiler flags
143+
flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
144+
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
145+
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
146+
flags, library_include_path.c_str(), device_runtime->get_arch());
147+
}
148+
149+
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
150+
// Write the code into the cache directory
151+
const auto& code_path = dir_path / "kernel.cu";
152+
put(code_path, code);
153+
154+
// Compile
155+
const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
156+
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
157+
printf("Running NVCC command: %s", command.c_str());
158+
const auto& [return_code, output] = call_external_command(command);
159+
if (return_code != 0) {
160+
printf("NVCC compilation failed: %s", output.c_str());
161+
DG_HOST_ASSERT(false and "NVCC compilation failed");
162+
}
163+
164+
// Print PTXAS log
165+
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
166+
printf("%s", output.c_str());
167+
}
168+
};
169+
170+
static std::shared_ptr<Compiler> compiler = nullptr;
171+
172+
} // namespace deep_gemm

csrc/jit/device_runtime.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "../utils/exception.hpp"

namespace deep_gemm {

// Caches the current CUDA device's properties and exposes convenience
// accessors for its architecture and SM count.
class DeviceRuntime {
    int num_sms = 0;  // 0 means "not yet resolved"
    std::shared_ptr<cudaDeviceProp> prop_cache;

public:
    explicit DeviceRuntime() = default;

    // Fetch the device properties once via ATen and memoize them
    std::shared_ptr<cudaDeviceProp> get_prop() {
        if (not prop_cache)
            prop_cache = std::make_shared<cudaDeviceProp>(*at::cuda::getCurrentDeviceProperties());
        return prop_cache;
    }

    // {compute capability major, minor}
    std::pair<int, int> get_arch_pair() {
        const auto& properties = get_prop();
        return {properties->major, properties->minor};
    }

    // Compute capability as a single integer, e.g. 9.0 -> 90
    int get_arch() {
        const auto& [major_version, minor_version] = get_arch_pair();
        return major_version * 10 + minor_version;
    }

    int get_arch_major() {
        return get_arch_pair().first;
    }

    // Override the SM count; must stay within the device's physical limit
    void set_num_sms(const int& new_num_sms) {
        DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount);
        num_sms = new_num_sms;
    }

    // Lazily default to the device's full SM count on first query
    int get_num_sms() {
        if (num_sms == 0)
            num_sms = get_prop()->multiProcessorCount;
        return num_sms;
    }
};

// Process-wide singleton describing the current device
static auto device_runtime = std::make_shared<DeviceRuntime>();

} // namespace deep_gemm

0 commit comments

Comments
 (0)