Skip to content

Commit

Permalink
Handle FP6 + SM75 + N>=64 edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasvanderwerff committed Sep 26, 2024
1 parent 0a4d70e commit 650ba03
Showing 1 changed file with 54 additions and 25 deletions.
79 changes: 54 additions & 25 deletions torchao/csrc/cuda/fp6_llm/fp6_linear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@
#include <stdio.h>
#include <assert.h>

inline bool isSM75GPU() {
int device;
cudaError_t err = cudaGetDevice(&device);
if (err != cudaSuccess) {
return false;
}

cudaDeviceProp props;
err = cudaGetDeviceProperties(&props, device);
if (err != cudaSuccess) {
return false;
}

return (props.major == 7) && (props.minor == 5);
}

template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
static void Kernel_Ex(cudaStream_t stream,
const uint4 *Weight,
Expand Down Expand Up @@ -80,38 +96,51 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream,
if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;

if (Split_K == 1) {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
if (isSM75GPU() && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
// For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
if (Split_K == 1) {
Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
} else {
Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);
}
}
else {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
} else {
if (Split_K == 1) {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
}
}
else {
switch (N_PowerOf2) {
case 8: Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 16: Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
}
}
}

if (Split_K != 1) {
// Reduction for SplitK
dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1);
dim3 BlockDim(WARP_SIZE, 1, 1);
SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);
}

return cudaGetLastError();
}

Expand Down

0 comments on commit 650ba03

Please sign in to comment.