Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/benchmark_fp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def benchmark(m: int, k: int, n: int):

for m in tqdm([1 << i for i in range(10)]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, n, k))
results.append(benchmark(m, k, n))

df = pd.DataFrame(results)
df.to_csv("fp6_llm_benchmark_results.csv", index=False)
Expand Down
85 changes: 59 additions & 26 deletions torchao/csrc/cuda/fp6_llm/fp6_linear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,35 @@
// limitations under the License.
//
// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
//
// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
// - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
//

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750 // at least Turing

#include "kernel_matmul.cuh"
#include "kernel_reduction.cuh"

#include <stdio.h>
#include <assert.h>

inline bool isSM75GPU() {
int device;
cudaError_t err = cudaGetDevice(&device);
if (err != cudaSuccess) {
return false;
}

cudaDeviceProp props;
err = cudaGetDeviceProperties(&props, device);
if (err != cudaSuccess) {
return false;
}

return (props.major == 7) && (props.minor == 5);
}

template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
static void Kernel_Ex(cudaStream_t stream,
const uint4 *Weight,
Expand Down Expand Up @@ -80,38 +100,51 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream,
if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;

if (Split_K == 1) {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
// For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
if (Split_K == 1) {
Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
} else {
Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);
}
}
else {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
} else {
if (Split_K == 1) {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
}
}
else {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
}
}
}

if (Split_K != 1) {
// Reduction for SplitK
dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1);
dim3 BlockDim(WARP_SIZE, 1, 1);
SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);
}

return cudaGetLastError();
}

Expand Down
10 changes: 10 additions & 0 deletions torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
// limitations under the License.
//
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh
//
// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
// - Added __CUDA_ARCH__ guards such that async operations are only executed for SM80 and up
//

#include "configs.h"
#include "utils_gmem.cuh"
Expand Down Expand Up @@ -140,7 +144,9 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
for(int j=0; j<REG_PER_THREAD_C_TENSOR_16_16; j++)
c[i][j] = 0.0f;
//
#if __CUDA_ARCH__ >= 800
cp_async_wait_all();
#endif
__syncthreads();

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -175,12 +181,16 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
// copying B tile from GlobalMemory to SharedMemory
CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
#if __CUDA_ARCH__ >= 800
cp_async_group_commit();
#endif
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2);
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3);
// Barriers and Synchronizations
#if __CUDA_ARCH__ >= 800
cp_async_wait_group<PIPELINE_LEVEL_GMEM-2>();
#endif
__syncthreads();
core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
// Updating global PTRs
Expand Down
35 changes: 35 additions & 0 deletions torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
// limitations under the License.
//
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh
//
// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
// - Replaced m16n8k16 Tensor core operation with two m16n8k8 operations
// - Accounted for a difference in expected parameters for the ldmatrix operation

/***************************************************************************
* Copyright 2023 The FLash-LLM Authors. All rights reserved.
Expand Down Expand Up @@ -55,6 +59,14 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
assert( warp_start_col==0 );
#endif

#if __CUDA_ARCH__ == 750
if (TilingConfig::WARP_COL_MMA_TENSORS==1) {
// For .target sm_75, all threads must contain valid addresses for the 'ldmatrix' op. below. Otherwise, the behavior is undefined.
// See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix
// To avoid this, we make threads 16-32 point to the same smem addresses as threads 0-15 by changing the lane id.
lane_id = lane_id % 16;
}
#endif
int col = (lane_id%8) + (lane_id/16)*8;
int row = (lane_id%16) / 8 * 8;
uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row]));
Expand All @@ -80,6 +92,28 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[
__device__ __forceinline__ void
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
{
#if __CUDA_ARCH__ == 750
// m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops.
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5},"
"{ %6 },"
"{ %7, %8, %9, %10 };"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]),
"r"(b[0]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5},"
"{ %6 },"
"{ %7, %8, %9, %10 };"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[2]), "r"(a[3]),
"r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));

#else
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
"{ %4, %5, %6, %7 },"
Expand All @@ -89,6 +123,7 @@ MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}

#endif
21 changes: 20 additions & 1 deletion torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
// limitations under the License.
//
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
//
// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
// - Replaced asynchronous copy operations with vectorized loads
//

#ifndef UTILS_GMEM_CUH
#define UTILS_GMEM_CUH
Expand All @@ -39,7 +43,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR,
GPTR_HALF += lane_id*8;
#pragma unroll
for(int i=0; i<SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE/16; i++) {
#if __CUDA_ARCH__ == 750
if (pred_guard) {
float4* SPTR_VEC = reinterpret_cast<float4*>(SPTR_HALF);
const float4* GPTR_VEC = reinterpret_cast<const float4*>(GPTR_HALF);
SPTR_VEC[0] = GPTR_VEC[0];
}
#else
cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard);
#endif
SPTR_HALF += 256; // Forward 512 Bytes
GPTR_HALF += 256; // Forward 512 Bytes
}
Expand Down Expand Up @@ -82,8 +94,15 @@ __device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ Shar
#pragma unroll
for (int i = 0; i < MaxIteration; i++) {
bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred;
#if __CUDA_ARCH__ == 750
if (AsyncCopyPred) {
float4* SharedPtrVec = reinterpret_cast<float4*>(&(*SharedPTR)[line_offset]);
const float4* GlobalPtrVec = reinterpret_cast<const float4*>(GlobalPTR);
SharedPtrVec[0] = GlobalPtrVec[0];
}
#else
cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
//
#endif
GlobalPTR += NumOfGroups * GlobalStride;
SharedPTR += NumOfGroups;
}
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/floatx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape
- Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations.
- Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1.
- On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 for some microbenchmark results.
- FP6 is supported for >=SM80 (Ampere generation) as well as SM75 (Turing generation) GPUs. However, SM75 support requires manual compilation of the C++/CUDA extensions (see the installation instructions in the [README](https://github.com/pytorch/ao/blob/main/README.md#installation) for details).

## End-to-End benchmarks

Expand Down
Loading