ROCm Sparse Marlin Kernels #1206

Draft · wants to merge 20 commits into main
29 changes: 20 additions & 9 deletions setup.py
@@ -46,9 +46,11 @@ def read_version(file_path="version.txt"):
CUDAExtension,
Member commented:
I think you might enjoy stack-based PR development: https://github.com/modularml/stack-pr

BuildExtension,
CUDA_HOME,
ROCM_HOME,
IS_WINDOWS
)

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

def get_extensions():
debug_mode = os.getenv('DEBUG', '0') == '1'
@@ -57,11 +59,11 @@ def get_extensions():

if not torch.cuda.is_available():
print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
if CUDA_HOME is None and torch.cuda.is_available():
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available():
print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")

use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
extension = CUDAExtension if use_cuda else CppExtension

if not IS_WINDOWS:
@@ -71,15 +73,14 @@ def get_extensions():
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}
if use_cuda and not IS_ROCM:
extra_compile_args["nvcc"] = ["-O3" if not debug_mode else "-O0", "-t=0",]

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
if "nvcc" in extra_compile_args:
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])

else:
@@ -107,9 +108,19 @@ def get_extensions():
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))

if use_cuda:

extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")

hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True))

if not IS_ROCM and use_cuda:
sources += cuda_sources

# TODO: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
sources += hip_sources

## TODO: remove this condition and use what we have in CUDA once we fix the individual builds.
ext_modules = [
extension(
"torchao._C",
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -52,7 +52,7 @@ static constexpr int min_thread_n = 128;
static constexpr int tile_size = 16;
static constexpr int max_par = 64;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && !defined(USE_ROCM)

template <const int num_bits, // weight bits
const int threads, // number of threads in a threadblock
102 changes: 93 additions & 9 deletions torchao/csrc/cuda/sparse_marlin/mem.h
@@ -19,6 +19,28 @@
#include "base.h"

namespace torchao {

Member commented:
@xw285cornell looking for some quick advice: do you recommend we support AMD by adding conditional compilation flags to our existing CUDA kernels, or should we be OK with some more copy-paste?

Member commented:
Chatted offline, and indeed ifdefs are the way to go.

#ifdef USE_ROCM
#include <hip/hip_runtime.h>

// Convert generic pointer to shared memory address for ROCm
template<typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) {
// First get the address as a size_t to handle all pointer sizes
size_t addr = reinterpret_cast<size_t>(ptr);

// Extract the lower 32 bits which represent the shared memory offset
// This is safe because shared memory addresses are always within 32-bit range
return static_cast<uint32_t>(addr & 0xFFFFFFFF);
}
#else
// For CUDA, use the native intrinsic
template<typename T>
__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}
#endif
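
For reference, a minimal usage sketch (not part of this diff; the kernel name and buffer size are made up): both back ends expose the same `cvta_to_shared` call, so the copy/load helpers below can stay branch-free at the call site.

```cpp
// Illustrative only: converting a shared-memory pointer with the shim above.
__global__ void cvta_to_shared_demo() {
  __shared__ int4 staging[32];
  // On CUDA this lowers to __cvta_generic_to_shared; on ROCm it keeps the
  // lower 32 bits of the generic address, which is the LDS offset.
  uint32_t smem_addr = cvta_to_shared(&staging[threadIdx.x % 32]);
  (void)smem_addr;  // would be fed to cp.async / ds_read as a register operand
}
```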

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
@@ -27,91 +49,144 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,

const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
#endif
}

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
#endif
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
#else
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
#endif
}

// Async copy fence.
__device__ inline void cp_async_fence() {
#ifdef USE_ROCM
__builtin_amdgcn_s_waitcnt(0);
#else
asm volatile("cp.async.commit_group;\n" ::);
#endif
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
#ifdef USE_ROCM
// For AMD GPUs, we use s_waitcnt
// This waits for all outstanding memory operations to complete
__builtin_amdgcn_s_waitcnt(0);
#else
// For NVIDIA GPUs, use the original instruction
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}
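
For context, a minimal sketch (not part of this PR) of how these wrappers are typically combined into a double-buffered global->shared prefetch loop. The tile size, buffer shape, and "compute" step are illustrative, and the CUDA path assumes an sm_80+ target where `cp.async` exists. Note that on the ROCm path both `cp_async_fence` and `cp_async_wait` currently wait on all outstanding memory operations, so the pipeline degenerates to a synchronous copy, but the call sites stay the same.

```cpp
// Illustrative sketch: double-buffered prefetch built on the wrappers above.
__global__ void prefetch_pipeline_sketch(const int4* __restrict__ gmem,
                                         int4* __restrict__ out,
                                         int num_tiles) {
  __shared__ int4 smem[2][128];           // two stages, 128 x 16 bytes each
  const int tid = threadIdx.x;            // assumes blockDim.x == 128

  // Prime stage 0.
  cp_async4(&smem[0][tid], &gmem[tid]);
  cp_async_fence();

  for (int t = 0; t < num_tiles; ++t) {
    const int cur = t & 1;
    if (t + 1 < num_tiles) {
      // Kick off the next tile while the current one is consumed.
      cp_async4(&smem[cur ^ 1][tid], &gmem[(t + 1) * 128 + tid]);
      cp_async_fence();
      cp_async_wait<1>();                 // tile t has landed; t+1 may be in flight
    } else {
      cp_async_wait<0>();                 // last tile: drain everything
    }
    __syncthreads();
    out[t * 128 + tid] = smem[cur][tid];  // stand-in for the real compute
    __syncthreads();                      // don't refill smem[cur] too early
  }
}
```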

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b128 %0, %1 offset:0\n"
"ds_read_b128 %2, %1 offset:16\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
: "v"(smem));
#else
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
#endif
}

__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b64 %0, %2 offset:0\n"
: "=v"(a[0]), "=v"(a[1])
: "v"(smem));
#else
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
#endif
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b128 %0, %1 offset:0\n"
"ds_read_b128 %2, %1 offset:16\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
: "v"(smem));
#else
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
#endif
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
do {
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
#ifdef USE_ROCM
asm volatile("flat_load_dword %0, %1 glc\n\t"
"s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
: "=v"(state)
: "v"(lock));
#else
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
#endif
} while (state != count);
}
__syncthreads();
}
@@ -127,10 +202,19 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
#ifdef USE_ROCM
asm volatile("s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
"s_memrealtime\n\t"
"s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
"flat_atomic_add_i32 %0, %1\n\t"
: "+v"(*lock)
: "v"(val));
#else
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
#endif
}
}
} // namespace torchao
} // namespace torchao
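
For context, a minimal sketch (not part of this PR) of how `barrier_acquire`/`barrier_release` serialize threadblocks over a shared output buffer. The kernel, buffer names, and reset handling are illustrative and assume `lock` points at a single zero-initialized int in global memory.

```cpp
// Illustrative sketch: gridDim.x blocks add their partial results into `out`
// one after another, ordered by blockIdx.x, using the helpers above.
__global__ void ordered_accumulate_sketch(float* __restrict__ out,
                                          const float* __restrict__ partials,
                                          int* lock, int n) {
  const float* my_part = partials + blockIdx.x * n;

  // Thread 0 spins until the lock equals this block's index, i.e. all
  // lower-indexed blocks have released the barrier.
  barrier_acquire(lock, blockIdx.x);

  for (int i = threadIdx.x; i < n; i += blockDim.x)
    out[i] += my_part[i];                 // serialized by the lock, no atomics

  __syncthreads();                        // whole block done before releasing
  barrier_release(lock, blockIdx.x == gridDim.x - 1);  // last block resets it
}
```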