From ce8fd7d286ea97f2408a8b89d8bfb90be1d161d2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 08:53:54 +0800 Subject: [PATCH 01/31] first update from upstream --- torchao/csrc/cuda/fp6_llm/configs.h | 24 +-- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 56 +++---- torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 141 ++++++++++-------- torchao/csrc/cuda/fp6_llm/ptx_mma.cuh | 42 +----- torchao/csrc/cuda/fp6_llm/utils_core.cuh | 136 +++++------------ .../cuda/fp6_llm/utils_parallel_dequant.cuh | 122 +++++++-------- 6 files changed, 202 insertions(+), 319 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h index 0a642fc805..60f6745048 100644 --- a/torchao/csrc/cuda/fp6_llm/configs.h +++ b/torchao/csrc/cuda/fp6_llm/configs.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h #ifndef CONFIGS_H #define CONFIGS_H @@ -63,28 +63,6 @@ struct TilingConfig { static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 }; -/************************ General Config for FP6-LLM **********************/ -#define WEIGHT_FRAG1_BIT_WIDTH 2 -#define WEIGHT_FRAG2_BIT_WIDTH 4 -#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH+WEIGHT_FRAG2_BIT_WIDTH) // 6 -//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // QuantGroupSize: 4*64 = 256 -/*************************** 64*64 Weghts of A WARP *************************/ -#define WEIGHT_PER_UNIT (WARP_M*WARP_K) // 64*64 -#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/8) // 1024 Bytes #doubleBuffer not takedn into consideration -#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/8) // 2048 Bytes #doubleBuffer not takedn into consideration -#define SMEM_SIZE_A1_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A1*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. -#define SMEM_SIZE_A2_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A2*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. 
-/******************** Gloabl Memory Layout For QUANTIZED DATA ******************/ -#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/128) // 64 -#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/128) // 128 -/******************** Register Allocation For QUANTIZED DATA ******************/ -#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT/WARP_SIZE) // 128 -#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*2) // 8 -#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*4) // 16 -/******************** Register Allocation For QUANT Scales ******************/ -#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers -#define WARP_REG_QUANT_SCALE_DISTRIBUTED 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread - #endif // CONFIGS_H diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 8db5d44303..48194e499a 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/fp6_linear.cu +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" @@ -20,7 +20,7 @@ #include #include -template +template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, const half *Scales, @@ -37,8 +37,8 @@ static void Kernel_Ex(cudaStream_t stream, printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); #endif - static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE, TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1); @@ -49,14 +49,12 @@ static void Kernel_Ex(cudaStream_t stream, GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); printf("\n"); #endif - QUANT_GEMM_Kernel<<>> + QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } -/* - * - */ -cudaError_t fp6_linear_kernel(cudaStream_t stream, +template +cudaError_t fpx_linear_kernel(cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, @@ -82,30 +80,30 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream, if (Split_K == 1) { switch (N_PowerOf2) { - case 8: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, 
Split_K); break; - case 128: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } } else { switch (N_PowerOf2) { - case 8: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } // Reduction for SplitK dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); @@ -136,7 +134,8 @@ After Equivalent transformation : trans(Out) = W * trans(In). 
Note that we [Outputs] _out_feats: tensor of shape [B, OC]; // half */ -torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, +template +torch::Tensor fpx_linear_forward_cuda(torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, int64_t splitK=1) @@ -163,22 +162,13 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - fp6_linear_kernel(0, // Using default stream here. - weight, - scales, - in_feats, - out_feats, - M, - N, - K, - Reduction_Workspace, - splitK); + fpx_linear_kernel(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); return _out_feats; } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda); + m.impl("torchao::fp6_llm_linear", &fpx_linear_forward_cuda<3, 2>); } } // namespace torchao diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index ed11fc8517..f2c137828d 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -12,36 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh #include "configs.h" #include "utils_gmem.cuh" #include "utils_core.cuh" +/************************** Bitwidth of Weight Segments ************************/ +#define BIT_WIDTH_1 1 +#define BIT_WIDTH_2 2 +#define BIT_WIDTH_4 4 +/*************************** 64*64 Weghts of Weight Matrix *********************/ +#define WEIGHT_PER_WARP (WARP_M*WARP_K) // 64*64 = 4096 +#define SMEM_SIZE_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/8) // 512 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/8) // 1024 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/8) // 2048 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_TB_1BIT (SMEM_SIZE_PER_WARP_1BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 6 KB; double buffer for 2-level pipeline A= 4 KB. +#define SMEM_SIZE_PER_TB_2BIT (SMEM_SIZE_PER_WARP_2BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_PER_TB_4BIT (SMEM_SIZE_PER_WARP_4BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. +#define SMEM_SIZE_PER_TB_A_TILE (SMEM_SIZE_PER_TB_1BIT+SMEM_SIZE_PER_TB_2BIT+SMEM_SIZE_PER_TB_4BIT) // used in fp6_linear.cu, Kernel_Ex(). 
+/******************** Gloabl Memory Layout For QUANTIZED DATA *******************/ +#define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/128) // 32 +#define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/128) // 64 +#define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/128) // 128 + /* * C = A*B * A: row major with ahead-of-time layout transformation, FP6 * B: col major, FP16 * C: col major, FP16 */ - template + template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, - int Split_K) + int Split_K) { #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); #endif - // 2+4 weight split - const uint4* Weight1 = Weight; - const uint4* Weight2 = Weight1 + M_Global*K_Global*2/128; + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + const uint4* Weight_1bit = Weight; + const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); + const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M_Global*K_Global*BIT_WIDTH_2/128 : 0); // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned extern __shared__ __align__(128) half smem[]; - half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles + half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); @@ -54,38 +77,48 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const size_t AverageNumBlock_K = NumBlock_K/Split_K; const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; size_t NumIter = AverageNumBlock_K; - if(BatchID(smem); - uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR+SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + uint32_t* AFrag_1BIT_SPTR = reinterpret_cast(smem); + uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT/4; + uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR + SMEM_SIZE_PER_TB_2BIT/4; // 8 buffers including double buffers, 12 for trible buffers // StartSPTR for each WARP - AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4; - AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4; + AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT/4; + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT/4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4; // Pre-fetch of A tile for(int i=0; i(AFrag_2BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4, WARP_StartGPTR_A1); - CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4, WARP_StartGPTR_A2); - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, 
WARP_StartGPTR_A_1BIT); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); + WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; + WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; + WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; @@ -100,10 +133,8 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block -#ifdef PIPELINE_LEVEL_SMEM uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) -#endif float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; for(int i=0; i(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); -#endif + initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { // Trible-Buffer for A Tile - uint32_t* __restrict__ read_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ read_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 -#ifdef PIPELINE_LEVEL_SMEM - uint32_t* __restrict__ read2_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; - uint32_t* __restrict__ read2_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; -#endif - uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; 
(2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read2_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; + uint32_t* __restrict__ read2_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; + uint32_t* __restrict__ read2_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; + uint32_t* __restrict__ write_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 // Trible-Buffer for B Tile // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR. half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; -#ifdef PIPELINE_LEVEL_SMEM half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; -#endif half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; - // Copying A tile from Global to Register, Bypassing L1, using double-buffer - CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); - CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); cp_async_group_commit(); - #ifdef PIPELINE_LEVEL_SMEM - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations cp_async_wait_group(); __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 
0); - // Updating global PTRs - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 - BTile_GPTR += TilingConfig::TILE_K; - #else - PipelinedCoreLoop(c, read_SPTR, read_SPTR_Frag1, read_SPTR_Frag2, Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 BTile_GPTR += TilingConfig::TILE_K; - // Barriers and Synchronizations - cp_async_wait_group(); - __syncthreads(); - #endif } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index bafdd0b4e3..1658352ee5 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. @@ -39,7 +39,6 @@ // MODIFICATION NOTE: to support MSVC // - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4] // - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR) -#ifdef PIPELINE_LEVEL_SMEM template __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4], half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], @@ -75,45 +74,6 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ } } } -#else -// Debug: Whether ldmatrix.trans is required??? 
-// B is in column-major -template -__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - int k_offset) { - #ifdef DEBUG_MODE - static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); - #endif - - const int warpId = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; - int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory - #ifdef DEBUG_MODE - assert( warp_start_col==0 ); - #endif - - int col = (lane_id%8) + (lane_id/16)*8; - int row = (lane_id%16) / 8 * 8; - uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][k_offset + row])); - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" - : "=r"(Reg[0][0]), "=r"(Reg[0][1]) - : "r"(smem_local_ptr)); - } - else { - #pragma unroll - for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) - { - asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) - : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); - } - } -} -#endif // MODIFICATION NOTE: to support MSVC, the function signature is changed from // MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b). diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index 07e37d85bc..7a6cd36a46 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh #ifndef UTILS_CORE_CUH #define UTILS_CORE_CUH @@ -24,7 +24,6 @@ #include "utils_parallel_dequant.cuh" -#ifdef PIPELINE_LEVEL_SMEM template __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); @@ -36,35 +35,50 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
-template +template __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], - uint32_t* __restrict__ A1_SPTR_read, - uint32_t* __restrict__ A2_SPTR_read, + uint32_t* __restrict__ A_1BIT_SPTR_read, + uint32_t* __restrict__ A_2BIT_SPTR_read, + uint32_t* __restrict__ A_4BIT_SPTR_read, half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // Writing registers // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2]; // NO double buffer - uint32_t a_2[4]; // NO double buffer - CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, 0); - CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, 0); - Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + uint32_t a_1bit[1]; // NO double buffer + uint32_t a_2bit[2]; // NO double buffer + uint32_t a_4bit[4]; // NO double buffer + if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1BIT_SPTR_read, 0); + if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); + if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
-template +template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], - uint32_t* __restrict__ A1_SPTR_read, - uint32_t* __restrict__ A2_SPTR_read, + uint32_t* __restrict__ A_1bit_SPTR_read, + uint32_t* __restrict__ A_2bit_SPTR_read, + uint32_t* __restrict__ A_4bit_SPTR_read, half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + #ifdef DEBUG_MODE assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block #endif @@ -94,100 +108,18 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG } } } - // Writing registers // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2]; // NO double buffer - uint32_t a_2[4]; // NO double buffer - CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, slice_id); - CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, slice_id); - Dequant_32FP6_4Way(a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + uint32_t a_1bit[1]; // NO double buffer + uint32_t a_2bit[2]; // NO double buffer + uint32_t a_4bit[4]; // NO double buffer + if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1bit_SPTR_read, slice_id); + if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); + if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); + Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } -#else -// Old version with naive pipeline design -template -__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) { - int lane_id = threadIdx.x % WARP_SIZE; - #pragma unroll - for(int i=0; i -__device__ __forceinline__ void PipelinedCoreLoop(float c[][REG_PER_THREAD_C_TENSOR_16_16], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* __restrict__ read_SPTR_Frag1, - uint32_t* __restrict__ read_SPTR_Frag2, - uint32_t* RPTR_Scales) -{ - #ifdef DEBUG_MODE - assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block - #endif - const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block - - // Reigsters to store FP32 results - uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2*2]; // double buffer is used - uint32_t a_2[4*2]; // double buffer is used - // Registers to store decompressed FP6 - uint32_t a [NumRegSets_a * 1][4]; // No double buffer - // Register to store FP16 B matrix (a slice) - uint32_t b [NumRegSets_b * 2][4]; // double buffer is used - - // Overlapped Smem and TC pipeline: pre-loading from shared to registers - CopyFromSharedToRegister_AFrag<2> (a_1, read_SPTR_Frag1); - CopyFromSharedToRegister_AFrag<4> (a_2, read_SPTR_Frag2); - B_FromSharedToReg (b, read_SPTR, 0); - - #pragma unroll - for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { - uint32_t (*b_read)[4] = b; - uint32_t (*b_write)[4] = b; - uint32_t *a_1_read = a_1; - uint32_t *a_1_write = a_1; - uint32_t *a_2_read = a_2; - uint32_t *a_2_write = a_2; - if(k%2==0) { - b_write += NumRegSets_b; - a_1_write += 2; - a_2_write += 4; - } - else { - b_read += NumRegSets_b; - a_1_read += 2; - a_2_read += 4; - } - // data loading - if (k + 1 < WARP_K_MMA_TENSORS) { - // updating SPTR for fragment1 and fragment2 - read_SPTR_Frag1 += 2*WARP_SIZE; - read_SPTR_Frag2 += 4*WARP_SIZE; - CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); - CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); - B_FromSharedToReg(b_write, read_SPTR, (k+1)*MMA_16); - } - // SIMT Dequant + Tensor Core computations - Dequant_32FP6_4Way(a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, dequantizing a slice each time - #pragma unroll - for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - if(TilingConfig::WARP_COL_MMA_TENSORS==1) - MMA_FP16_M16N8K16( c_uint_ptr[i], a[i], b_read[0] ); - else { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a[i], b_read[j] + 2 ); // c+4; b+2 - } - } - } - } -} -#endif // #ifdef PIPELINE_LEVEL_SMEM - template __device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], float c[][REG_PER_THREAD_C_TENSOR_16_16]) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 48b0f968bb..4c8c39603e 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh // To support MSVC, all instances of u_int32_t are changed to uint32_t. #ifndef UTILS_PARALLELDEQUANT_CUH @@ -27,86 +27,90 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. 
*/ -__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) { - *R2 = *R1 & 0x80808080; - *R1 = *R1 >> 2; - *R1 = *R1 & 0x1f1f1f1f; - *R2 = *R2 | *R1; - *R1 = *R2 & 0x9f009f00; - *R2 = *R2 & 0x009f009f; - *R2 = *R2 << 8; -} - -/* - * Input: R1 - * Outputs: R1, R2 - * Note: Simplified Exponent calculation is NOT applied. - */ -__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) { - //*R2 = *R1 & 0x80808080; - *R2 = *R1 & 0xc0c0c0c0; - *R1 = *R1 >> 2; - //*R1 = *R1 & 0x1f1f1f1f; - *R1 = *R1 & 0x0f0f0f0f; - *R2 = *R2 | *R1; +template +__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) { + // + constexpr int RIGHT_SHIFT = 5 - EXPONENT; + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | MASK3 >> 16; // - //*R1 = *R2 & 0x9f009f00; - //*R2 = *R2 & 0x009f009f; - *R1 = *R2 & 0xcf00cf00; - if( !(*R1 & 0x40000000) && (*R1 & 0x0c000000) ) *R1 = *R1 | 0x30000000; - if( !(*R1 & 0x00004000) && (*R1 & 0x00000c00) ) *R1 = *R1 | 0x00003000; - *R2 = *R2 & 0x00cf00cf; - if( !(*R2 & 0x00400000) && (*R2 & 0x000c0000) ) *R2 = *R2 | 0x00300000; - if( !(*R2 & 0x00000040) && (*R2 & 0x0000000c) ) *R2 = *R2 | 0x00000030; + *Out1 = *In & 0x80008000; + *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT; // - *R2 = *R2 << 8; - //*R1 = 0x3c003c00; - //*R2 = 0x3c003c00; + *In = (*In) << 8; + *Out2 = *In & 0x80008000; + *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT; } +template __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { + constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1)); + constexpr int BIAS = int(1) << BIAS_OFFSET; + // half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); - output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(4096.0f)), Scale); - output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(4096.0f)), Scale); + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); return output; } // MODIFICATION NOTE: to support MSVC // - u_int32_t __restrict__ Reg[][4] is changed to below. -// - u_int32_t __restrict__ *read_RPTR_Frag1 is changed to below. similarly for read_RPTR_Frag2 +// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. 
similarly for read_RPTR_2bit and read_RPTR_4bit +template __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4], - uint32_t * __restrict__ read_RPTR_Frag1, - uint32_t * __restrict__ read_RPTR_Frag2, + uint32_t * __restrict__ read_RPTR_1bit, + uint32_t * __restrict__ read_RPTR_2bit, + uint32_t * __restrict__ read_RPTR_4bit, uint32_t * Scales) { - uint32_t *OutputRegs = reinterpret_cast (Reg); - uint32_t *Frag1_PTR = read_RPTR_Frag1; - uint32_t *Frag2_PTR = read_RPTR_Frag2; - half *Scale_RPTR = reinterpret_cast(Scales); - uint32_t Packed_FP6 = 0; - uint32_t tmp = 0; + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + // + uint32_t *OutputRegs = reinterpret_cast (Reg); + uint32_t *Frag_PTR_1bit = read_RPTR_1bit; + uint32_t *Frag_PTR_2bit = read_RPTR_2bit; + uint32_t *Frag_PTR_4bit = read_RPTR_4bit; + half *Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for(int i=0; i<8; i++) { - // Frag1 - Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; - if(i%4==3) Frag1_PTR++; - else (*Frag1_PTR) = (*Frag1_PTR) << 2; - // Frag2 - tmp = (*Frag2_PTR) & 0xf0f0f0f0; - tmp = tmp >> 2; - if(i%2==1) Frag2_PTR++; - else (*Frag2_PTR) = (*Frag2_PTR) << 4; - // Packed_FP6 - Packed_FP6 = Packed_FP6 | tmp; + uint32_t Packed_FP6 = 0; + uint32_t tmp = 0; + // 1bit Frag + if(USE_SEG_1BIT) { + tmp = (*Frag_PTR_1bit) & 0x80808080; + Packed_FP6 |= tmp >> (BIT_WIDTH & 0); + if(i%8==7) Frag_PTR_1bit++; + else (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1; + } + // 2bit Frag + if(USE_SEG_2BIT) { + tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0; + Packed_FP6 |= tmp >> (BIT_WIDTH & 1); + if(i%4==3) Frag_PTR_2bit++; + else (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2; + } + // 4bit Frag2 + if(USE_SEG_4BIT) { + tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0; + Packed_FP6 |= tmp >> (BIT_WIDTH & 3); + if(i%2==1) Frag_PTR_4bit++; + else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; + } // - FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); + uint32_t out1, out2; + FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); // - *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0] ); // Muliply FP16 scales + *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16 scales OutputRegs += 1; - *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16 scales OutputRegs += 1; // Updating offset for FP16 scales for every two iterations if(i%2==1) Scale_RPTR += 2; From d8bd7b608eb9fef2a5482fc10a01a63df865bb64 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 10:00:43 +0800 Subject: [PATCH 02/31] add some primitives to support fp5 --- test/prototype/test_fp6_llm.py | 4 +- torchao/prototype/fp6_llm/fp6_llm.py | 106 ++++++++++++++------------- 2 files changed, 58 insertions(+), 52 deletions(-) diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py index 9ee3faae4a..fdb228023e 100644 --- a/test/prototype/test_fp6_llm.py +++ b/test/prototype/test_fp6_llm.py @@ -10,7 +10,7 @@ from torchao.prototype.fp6_llm.fp6_llm import ( to_tc_float6_e3m2, from_tc_float6_e3m2, - _to_tc_float6_e3m2_ref, + _to_tc_fpx, Fp6LlmLinear, convert_fp6_llm, ) @@ -25,7 +25,7 @@ class TestFp6LlmLinear(TestCase): def test_to_tc_float6_e3m2_correctness(self, device): x = torch.randn(256, 64, device=device) - expected = _to_tc_float6_e3m2_ref(x) + expected = _to_tc_fpx(x, 3, 2) 
actual = to_tc_float6_e3m2(x) torch.testing.assert_close(actual, expected) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 570ea13546..4c4f3d985e 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -1,72 +1,78 @@ +from functools import reduce import math from typing import List, Optional, Tuple import torch from torch import nn, Tensor -from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32 +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 from torchao.prototype.mx_formats.constants import F6_E3M2_MAX from torchao.ops import fp6_llm_linear -def _pack_2bit(x: Tensor) -> Tensor: - return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] +def _pack(x: Tensor, n_bits: int) -> Tensor: + return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) -def _unpack_2bit(x: Tensor) -> Tensor: - return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) +def _unpack(x: Tensor, n_bits: int) -> Tensor: + return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) -def _pack_4bit(x: Tensor) -> Tensor: - return (x[..., ::2] << 4) | x[..., 1::2] +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 +def _bit_interleave(x: Tensor, n_bits: int) -> Tensor: + # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 + # thus, we need to reverse byte order within a uint32 word. + x = x.reshape(-1, 4).flip(1) + x = _unpack(x, n_bits) + x = x.view(-1, 4 * (8 // n_bits)) -def _unpack_4bit(x: Tensor) -> Tensor: - return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) + bit_order = { + 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, + 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], + 4: [1, 5, 3, 7, 0, 4, 2, 6] + }[n_bits] + x = x[:, bit_order] + + x = _pack(x, n_bits) + + # reverse byte order within a uint32 word again. 
+ x = x.reshape(-1, 4).flip(1) + return x.flatten() # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def _to_tc_float6_e3m2_ref(tensor: Tensor) -> Tensor: +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h +def _to_tc_fpx(tensor: Tensor, n_ebits: int, n_mbits: int) -> Tensor: assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor.float()) + tensor_fpx = _f32_to_fpx_unpacked(tensor.float(), n_ebits, n_mbits) # Pass 1 from original code - tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) - tensor_fp6 = tensor_fp6.permute(1, 0, 2) - tensor_fp6 = tensor_fp6.flatten() - - tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11) - tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111) - - # Pass 2 from original code - tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - - # Pass 3 from original code - # BitInterleaving_2bit - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = _pack_2bit(tensor_2bit).view(-1) - - # BitInterleaving_4bit - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = _pack_4bit(tensor_4bit).view(-1) + tensor_fpx = tensor_fpx.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor_fpx = tensor_fpx.permute(0, 4, 1, 5, 2, 3, 6) + tensor_fpx = tensor_fpx.reshape(-1, 32, 2) + tensor_fpx = tensor_fpx.permute(1, 0, 2) + tensor_fpx = tensor_fpx.flatten() - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) + n_bits = 1 + n_ebits + n_mbits + n_used = 0 + fragments = [] + + for y in [1, 2, 4]: + if n_bits & y: + mask = (1 << y) - 1 + tensor_ybit = (tensor_fpx >> (n_bits - n_used - y)) & mask + tensor_ybit = _pack(tensor_ybit, y) + + tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code + tensor_ybit = _bit_interleave(tensor_ybit, y) # Pass 3 from original code + fragments.append(tensor_ybit) + n_used += y + + return torch.cat(fragments, dim=0).view(M, -1) # more optimized version of _to_tc_float6_e3m2_original() by merging ops @@ -76,17 +82,17 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor.float()) + tensor_fp6 = _f32_to_fpx_unpacked(tensor.float(), 3, 2) tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) tensor_fp6 = tensor_fp6.flip(3) tensor_2bit = (tensor_fp6 >> 4) & 
0b11 tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) - tensor_2bit = _pack_2bit(tensor_2bit.flatten()) + tensor_2bit = _pack(tensor_2bit.flatten(), 2) tensor_4bit = tensor_fp6 & 0b1111 tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) - tensor_4bit = _pack_4bit(tensor_4bit.flatten()) + tensor_4bit = _pack(tensor_4bit.flatten(), 4) return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) @@ -109,17 +115,17 @@ def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> T tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) - tensor_2bit = _unpack_2bit(tensor_2bit) + tensor_2bit = _unpack(tensor_2bit, 2) tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) - tensor_4bit = _unpack_4bit(tensor_4bit) + tensor_4bit = _unpack(tensor_4bit, 4) tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) tensor_fp6 = (tensor_2bit << 4) | tensor_4bit tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) - return f6_e3m2_unpacked_to_f32(tensor_fp6).to(dtype) + return _fpx_unpacked_to_f32(tensor_fp6, 3, 2).to(dtype) # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py From 129adff4713ed23c175321d19ee98f3effeaf076 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 10:19:41 +0800 Subject: [PATCH 03/31] binding for ExMy --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 62 +++++++++++++++++++++---- torchao/csrc/fp6_llm.cpp | 1 + 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 48194e499a..e09deed9df 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -120,7 +120,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, namespace torchao { /* -Computes FP6-FP16 GEMM (PyTorch interface). +Computes FPx-FP16 GEMM (PyTorch interface). [Mathmatical Formula] Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. @@ -128,17 +128,19 @@ After Equivalent transformation : trans(Out) = W * trans(In). Note that we [Inputs] _in_feats: tensor of shape [B, IC]; // half - _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _weights: int tensor of shape [OC, IC // 32 * x]; // x INT32 words contains 32 FPx weights. _scales: tensor of shape [OC]; // half splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. 
[Outputs] _out_feats: tensor of shape [B, OC]; // half */ -template -torch::Tensor fpx_linear_forward_cuda(torch::Tensor _in_feats, - torch::Tensor _weights, - torch::Tensor _scales, - int64_t splitK=1) +torch::Tensor fp_eXmY_linear_forward_cuda( + int64_t EXPONENT, + int64_t MANTISSA, + torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int64_t splitK=1) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); @@ -161,14 +163,54 @@ torch::Tensor fpx_linear_forward_cuda(torch::Tensor _in_feats, options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - - fpx_linear_kernel(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + if (EXPONENT == 3 && MANTISSA == 2) + fpx_linear_kernel<3, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + // experimental + else if (EXPONENT == 2 && MANTISSA == 3) + fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + else if (EXPONENT == 2 && MANTISSA == 2) + fpx_linear_kernel<2, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + // experimental + else if (EXPONENT == 3 && MANTISSA == 1) + fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + else + TORCH_CHECK(false, "Only FP6 E3M2 and FP5 E2M2 are supported"); return _out_feats; } +/* +Computes FP6-FP16 GEMM (PyTorch interface). + +[Mathmatical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. +After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. + +[Inputs] + _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _scales: tensor of shape [OC]; // half + splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. 
+[Outputs] + _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp6_linear_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int64_t splitK=1) +{ + return fp_eXmY_linear_forward_cuda(3, 2, _in_feats, _weights, _scales, splitK); +} + TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp6_llm_linear", &fpx_linear_forward_cuda<3, 2>); + m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda); + m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); } } // namespace torchao diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp index bd787385c0..00c603b1f2 100644 --- a/torchao/csrc/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm.cpp @@ -5,4 +5,5 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); + m.def("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); } From 64c6cee59095c1c4bdc15f83a39ae9fa8ab5826d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 10:26:19 +0800 Subject: [PATCH 04/31] add QuantLlmLinear --- torchao/prototype/fp6_llm/fp6_llm.py | 84 ++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 4c4f3d985e..6641eb505a 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -43,12 +43,12 @@ def _bit_interleave(x: Tensor, n_bits: int) -> Tensor: # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h -def _to_tc_fpx(tensor: Tensor, n_ebits: int, n_mbits: int) -> Tensor: +def _to_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fpx = _f32_to_fpx_unpacked(tensor.float(), n_ebits, n_mbits) + tensor_fpx = _f32_to_fpx_unpacked(tensor.float(), ebits, mbits) # Pass 1 from original code tensor_fpx = tensor_fpx.view(M // 64, 4, 2, 8, N // 16, 2, 8) @@ -57,7 +57,7 @@ def _to_tc_fpx(tensor: Tensor, n_ebits: int, n_mbits: int) -> Tensor: tensor_fpx = tensor_fpx.permute(1, 0, 2) tensor_fpx = tensor_fpx.flatten() - n_bits = 1 + n_ebits + n_mbits + n_bits = 1 + ebits + mbits n_used = 0 fragments = [] @@ -307,7 +307,83 @@ def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[List[str]] = None, if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): if (child.in_features % 64 == 0) and (child.out_features % 256 == 0): - new_child = Fp6LlmLinear.from_float(child) + new_child = Fp6LlmLinear.from_float(child) setattr(model, name, new_child) else: convert_fp6_llm(child, skip_fqn_list, new_fqn) + + +class QuantLlmLinear(nn.Module): + """Quant-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. 
+ """ + + def __init__( + self, + ebits: int, + mbits: int, + weight: Tensor, + scales: Tensor, + bias: Optional[Tensor] = None, + ) -> None: + super().__init__() + self.register_buffer("weight", weight.view(torch.int32)) + self.register_buffer("scales", scales) + self.register_buffer("bias", bias) + self.out_features = weight.shape[0] + self.in_features = weight.shape[1] // 3 * 4 + self.ebits = ebits + self.mbits = mbits + + def forward(self, x: Tensor) -> Tensor: + splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features) + out = torch.ops.torchao.quant_llm_linear.default( + self.ebits, + self.mbits, + x.view(-1, self.in_features).half(), + self.weight, + self.scales, + splitK=splitK, + ) + if self.bias is not None: + out = out + self.bias + return out.view(*x.shape[:-1], self.out_features).to(x.dtype) + + @staticmethod + def get_split_k(bsize: int, out_dim: int) -> int: + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + @classmethod + def from_float(cls, linear: nn.Linear, ebits: int, mbits: int): + assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) + + fp6_weight, scale = _to_tc_fpx(linear.weight.detach(), ebits, mbits) + bias = linear.bias.detach().half() if linear.bias is not None else None + return cls(ebits, mbits, fp6_weight, scale, bias) + + def extra_repr(self) -> str: + return ( + f'in_features={self.in_features}' + f', out_features={self.out_features}' + f', bias={self.bias is not None}' + f', ebits={self.ebits}' + f', mbits={self.mbits}' + ) + + +def convert_quant_llm( + model: nn.Module, + ebits: int, + mbits: int, + skip_fqn_list: Optional[List[str]] = None, + cur_fqn: str = "", +) -> None: + for name, child in model.named_children(): + new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" + + if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): + if (child.in_features % 64 == 0) and (child.out_features % 256 == 0): + new_child = QuantLlmLinear.from_float(child, ebits, mbits) + setattr(model, name, new_child) + else: + convert_quant_llm(child, ebits, mbits, skip_fqn_list, new_fqn) From 38ad773be38fe5a113d5f1f7a8a39532d925df41 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 10:52:58 +0800 Subject: [PATCH 05/31] fix --- torchao/ops.py | 25 +++++++++++++++++++++++++ torchao/prototype/fp6_llm/fp6_llm.py | 19 ++++++++++++++----- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/torchao/ops.py b/torchao/ops.py index 25cbfb5656..3958cda9ef 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -43,3 +43,28 @@ def _(_in_feats, _weights, _scales, splitK = 1): torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) + + +def quant_llm_linear( + EXPONENT: int, + MANTISSA: int, + _in_feats: Tensor, + _weights: Tensor, + _scales: Tensor, + splitK: int = 1, +) -> Tensor: + """ + Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. + + Arguments + EXPONENT: number of exponent bits + MANTISSA: number of mantissa bits + _in_feats: input activations in FP16 + _weights: packed FP6 weights. 
See :func:prepack_fp6_weight and :func:fp16_to_fp6 + _scales: scale + splitK: split K + + Returns + output of linear layer + """ + return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 6641eb505a..f013e1dc51 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -4,9 +4,9 @@ import torch from torch import nn, Tensor -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.prototype.mx_formats.constants import F6_E3M2_MAX -from torchao.ops import fp6_llm_linear +from torchao.ops import fp6_llm_linear, quant_llm_linear def _pack(x: Tensor, n_bits: int) -> Tensor: @@ -75,6 +75,15 @@ def _to_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: return torch.cat(fragments, dim=0).view(M, -1) +def _to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: + exp_bias = _n_ones(ebits - 1) + max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) + + scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal + tc_fpx_tensor = to_tc_float6_e3m2(tensor / scale.view(-1, 1)) + return tc_fpx_tensor, scale.half() + + # more optimized version of _to_tc_float6_e3m2_original() by merging ops # https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: @@ -336,7 +345,7 @@ def __init__( def forward(self, x: Tensor) -> Tensor: splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features) - out = torch.ops.torchao.quant_llm_linear.default( + out = quant_llm_linear( self.ebits, self.mbits, x.view(-1, self.in_features).half(), @@ -357,9 +366,9 @@ def get_split_k(bsize: int, out_dim: int) -> int: def from_float(cls, linear: nn.Linear, ebits: int, mbits: int): assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) - fp6_weight, scale = _to_tc_fpx(linear.weight.detach(), ebits, mbits) + fpx_weight, scale = _to_scaled_tc_fpx(linear.weight.detach(), ebits, mbits) bias = linear.bias.detach().half() if linear.bias is not None else None - return cls(ebits, mbits, fp6_weight, scale, bias) + return cls(ebits, mbits, fpx_weight, scale, bias) def extra_repr(self) -> str: return ( From 057367edf81e34da7ee33a633ba56efb4e24b7a9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:02:00 +0800 Subject: [PATCH 06/31] update README --- torchao/prototype/fp6_llm/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md index 767785275b..5b34c16616 100644 --- a/torchao/prototype/fp6_llm/README.md +++ b/torchao/prototype/fp6_llm/README.md @@ -31,6 +31,8 @@ fp16_act = torch.randn(1, 512).cuda().half() outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024) ``` +**NOTE**: since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization. 
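A minimal end-to-end sketch of that recommended flow, using the `convert_fp6_llm` helper shown above (the two-layer model here is a stand-in chosen to satisfy the divisibility requirements, not part of the patch):

```python
import torch
from torch import nn
from torchao.prototype.fp6_llm import convert_fp6_llm

# stand-in model: in_features % 64 == 0 and out_features % 256 == 0 for every layer
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).cuda()

model = model.half()    # convert to FP16 first; the kernel computes in FP16
convert_fp6_llm(model)  # replace eligible nn.Linear modules in-place

x = torch.randn(1, 1024, device="cuda", dtype=torch.half)
with torch.no_grad():
    y = model(x)        # shape (1, 1024)
```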
+ ## TODO - [ ] Compile CUDA kernel for Windows From a6ed669f0b6a1ed852e51a0409741704d6dce675 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:02:37 +0800 Subject: [PATCH 07/31] update README --- torchao/prototype/fp6_llm/README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md index 5b34c16616..2d64d7313a 100644 --- a/torchao/prototype/fp6_llm/README.md +++ b/torchao/prototype/fp6_llm/README.md @@ -33,10 +33,9 @@ outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024) **NOTE**: since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization. -## TODO +## Benchmark results -- [ ] Compile CUDA kernel for Windows -- [ ] Merge FP5 from upstream +TODO ## Credits From 0d409a805b1da6ec95e5ee7b44ddb96eebdc82e9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:09:44 +0800 Subject: [PATCH 08/31] remove fp6_linear from C++ --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 25 ----------------- torchao/csrc/fp6_llm.cpp | 1 - torchao/ops.py | 37 +++++++++++++------------ torchao/prototype/fp6_llm/__init__.py | 2 +- 4 files changed, 20 insertions(+), 45 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index e09deed9df..73732bb1a8 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -184,32 +184,7 @@ torch::Tensor fp_eXmY_linear_forward_cuda( return _out_feats; } -/* -Computes FP6-FP16 GEMM (PyTorch interface). - -[Mathmatical Formula] -Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. -After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. - -[Inputs] - _in_feats: tensor of shape [B, IC]; // half - _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. - _scales: tensor of shape [OC]; // half - splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. 
-[Outputs] - _out_feats: tensor of shape [B, OC]; // half -*/ -torch::Tensor fp6_linear_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _weights, - torch::Tensor _scales, - int64_t splitK=1) -{ - return fp_eXmY_linear_forward_cuda(3, 2, _in_feats, _weights, _scales, splitK); -} - TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda); m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); } diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp index 00c603b1f2..861cdbf6db 100644 --- a/torchao/csrc/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm.cpp @@ -4,6 +4,5 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); - m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 3958cda9ef..127f54192b 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -25,24 +25,7 @@ def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: Returns output of linear layer """ - return torch.ops.torchao.fp6_llm_linear.default(_in_feats, _weights, _scales, splitK) - - -@register_custom_op("torchao::fp6_llm_linear") -def _(_in_feats, _weights, _scales, splitK = 1): - torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") - torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") - torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") - torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}") - torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") - - BS, IC = _in_feats.shape - OC, _ = _weights.shape - torch._check(IC / 16 * 3 == _weights.shape[1], lambda: "Dimensions mismatched") - torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") - - return _in_feats.new_empty((BS, OC)) + return quant_llm_linear(3, 2, _in_feats, _weights, _scales, splitK) def quant_llm_linear( @@ -68,3 +51,21 @@ def quant_llm_linear( output of linear layer """ return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK) + + +@register_custom_op("torchao::quant_llm_linear") +def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1): + torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") + torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") + torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") + torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}") + torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") + torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") + + BS, IC = _in_feats.shape + OC, _ = _weights.shape + N_BITS = 1 + EXPONENT + MANTISSA + torch._check(IC // 32 * N_BITS == _weights.shape[1], lambda: "Dimensions mismatched") + torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") + + return _in_feats.new_empty((BS, OC)) diff --git 
a/torchao/prototype/fp6_llm/__init__.py b/torchao/prototype/fp6_llm/__init__.py index d1a46339bd..befdfbec40 100644 --- a/torchao/prototype/fp6_llm/__init__.py +++ b/torchao/prototype/fp6_llm/__init__.py @@ -1 +1 @@ -from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2 +from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2, convert_quant_llm From 0bd9ee101c9694f1d5831491d2aed2fd1cafd82f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:14:10 +0800 Subject: [PATCH 09/31] fix --- torchao/prototype/fp6_llm/fp6_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index f013e1dc51..290fd2ebd4 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -80,7 +80,7 @@ def _to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, T max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - tc_fpx_tensor = to_tc_float6_e3m2(tensor / scale.view(-1, 1)) + tc_fpx_tensor = _to_tc_fpx(tensor / scale.view(-1, 1), ebits, mbits) return tc_fpx_tensor, scale.half() From 8fbd3d438952740f2807cdd7e9e6d059f6086174 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:15:20 +0800 Subject: [PATCH 10/31] fix --- torchao/prototype/fp6_llm/fp6_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 290fd2ebd4..83d9d12051 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -339,7 +339,7 @@ def __init__( self.register_buffer("scales", scales) self.register_buffer("bias", bias) self.out_features = weight.shape[0] - self.in_features = weight.shape[1] // 3 * 4 + self.in_features = weight.shape[1] // (1 + ebits + mbits) * 8 self.ebits = ebits self.mbits = mbits From 5906eedaba65c83ec85367d153f3a002160553d6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:17:36 +0800 Subject: [PATCH 11/31] fix --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 73732bb1a8..5311d3b2b1 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -142,11 +142,12 @@ torch::Tensor fp_eXmY_linear_forward_cuda( torch::Tensor _scales, int64_t splitK=1) { + const int64_t NBITS = 1 + EXPONENT + MANTISSA; int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); int num_out_channels = _weights.size(0); - TORCH_CHECK(num_in_channels%64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); - TORCH_CHECK((num_in_channels/16*3) == _weights.size(1)); // Making sure the K dimension is matched. + TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); + TORCH_CHECK((num_in_channels / 32 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. 
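A quick numerical check of the generalized shape test above: with NBITS = 6 and num_in_channels = 4096, the expected packed width is 4096 / 32 * 6 = 768 INT32 words per output channel, which matches the FP6-specific form 4096 / 16 * 3 = 768 that it replaces.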
// int M = num_out_channels; int K = num_in_channels; From 3b008d564ebb080862aae2b875adb4c260e1236e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 11:26:47 +0800 Subject: [PATCH 12/31] update --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 5311d3b2b1..74532bd2c1 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -167,20 +167,19 @@ torch::Tensor fp_eXmY_linear_forward_cuda( if (EXPONENT == 3 && MANTISSA == 2) fpx_linear_kernel<3, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - // experimental - else if (EXPONENT == 2 && MANTISSA == 3) - fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 2 && MANTISSA == 2) fpx_linear_kernel<2, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); // experimental + else if (EXPONENT == 2 && MANTISSA == 3) + fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else if (EXPONENT == 3 && MANTISSA == 1) fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 1) + fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else - TORCH_CHECK(false, "Only FP6 E3M2 and FP5 E2M2 are supported"); + TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); return _out_feats; } From 9076e5873a89798d6cc82c72dac2fc6787e8aa2c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 12:31:50 +0800 Subject: [PATCH 13/31] add more experimental config --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 74532bd2c1..3f0c976e78 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -165,6 +165,7 @@ torch::Tensor fp_eXmY_linear_forward_cuda( at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + // officially supported in Quant-LLM if (EXPONENT == 3 && MANTISSA == 2) fpx_linear_kernel<3, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else if (EXPONENT == 2 && MANTISSA == 2) @@ -177,6 +178,10 @@ torch::Tensor fp_eXmY_linear_forward_cuda( fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else if (EXPONENT == 2 && MANTISSA == 1) fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 0) + fpx_linear_kernel<3, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 0) + fpx_linear_kernel<2, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); From 442e9c53c7c377206eeefcee57877f1b2bea679c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 13:31:27 +0800 Subject: [PATCH 14/31] update --- 
test/test_ops.py | 69 ++++++++++++++++------------- torchao/prototype/fp6_llm/README.md | 18 +++++++- 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 920b32c5f2..3728ba202f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,60 +1,69 @@ import torch -from torch.testing._internal.common_utils import TestCase, IS_FBCODE +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) from torch.testing._internal.optests import opcheck -import torchao +from torchao.utils import is_fbcode from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2 -import unittest -from parameterized import parameterized import pytest +if is_fbcode(): + pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels") + try: import torchao.ops except RuntimeError: pytest.skip("torchao.ops not available") -# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): -# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) -@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning") -@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") class TestOps(TestCase): - def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device): - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. - fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) - fp16_scale = torch.rand(OC).half() + 0.5 - fp16_activation = torch.rand(BS, IC).half() + 0.5 - return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_llm_linear(self): + def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): + # Randomly initialize each byte + nbits = 1 + ebits + mbits + fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8).view(torch.int32) + scale = torch.rand(OC).half() + 0.5 + fp16_act = torch.rand(BS, IC).half() + 0.5 + return fpx_weight.to(device), scale.to(device), fp16_act.to(device) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + def test_quant_llm_linear(self, ebits, mbits): BS = 2 OC = 256 IC = 256 splitK = 1 - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") + fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") # smoke test - torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK) + torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils) + opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils) - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py - @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_llm_linear_correctness(self, BS, OC, IC, 
splitK): - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py + fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") - results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK) + results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) - fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] - results_fp16 = fp16_activation @ fp16_weight.T + # TODO: add from_scaled_tc_fpx() + fp16_weight = from_tc_float6_e3m2(fpx_weight.view(torch.uint8), dtype=torch.float16) * scale[:, None] + results_fp16 = fp16_act @ fp16_weight.T - error = (results_fp6 - results_fp16).abs() + error = (results_fpx - results_fp16).abs() relative_error = error / results_fp16.abs() assert relative_error.mean() < 1e-2 +instantiate_parametrized_tests(TestOps) + + if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md index 2d64d7313a..6797694f0a 100644 --- a/torchao/prototype/fp6_llm/README.md +++ b/torchao/prototype/fp6_llm/README.md @@ -31,11 +31,25 @@ fp16_act = torch.randn(1, 512).cuda().half() outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024) ``` -**NOTE**: since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization. +**NOTE**: since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations. ## Benchmark results -TODO +Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py) which generates text in a latency optimized way (batchsize=1). wikitext perplexity is measured using [eval.py](../../_models/llama/eval.py) which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). + +FPx quantization is run with `--precision float16`. The rest uses the default precision of `bfloat16`. 
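The FPx rows in the table below map directly onto `(ebits, mbits)` pairs, e.g. FP6 E3M2 is `(3, 2)` and FP5 E2M2 is `(2, 2)`. A rough sketch of applying one of these configurations before benchmarking (the small Sequential model is a placeholder for the Llama-2 checkpoint):

```python
import torch.nn as nn
from torchao.prototype.fp6_llm import convert_quant_llm

# placeholder model; the benchmark uses meta-llama/Llama-2-7b-chat-hf
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda().half()

convert_quant_llm(model, ebits=3, mbits=2)  # FP6 E3M2; use (2, 2) for FP5 E2M2, etc.
```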
+ +Quantization | wikitext perplexity | tokens/s +--------------------|---------------------|---------- +INT8 | 12.21 | 87.45 +INT4-256 (tinygemm) | 76266957.87 (bug) | 157.10 +FP6 E3M2 | 12.34 | 106.76 +FP6 E2M3 | 12.23 | 106.77 +FP5 E3M1 | 12.55 | 122.69 +FP5 E2M2 | 12.47 | 122.66 +FP4 E3M0 | 14.58 | 145.55 +FP4 E2M1 | 15.01 | 146.05 +FP3 E2M0 | 74625.18 | 164.49 ## Credits From d2d8019cea4dad0ee949d160c6971f8b02bd9a75 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 17:58:50 +0800 Subject: [PATCH 15/31] add from tc_fpx --- test/test_ops.py | 6 +-- torchao/prototype/fp6_llm/fp6_llm.py | 55 ++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3728ba202f..615b6e9796 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,7 +7,7 @@ ) from torch.testing._internal.optests import opcheck from torchao.utils import is_fbcode -from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2 +from torchao.prototype.fp6_llm.fp6_llm import _from_tc_fpx import pytest if is_fbcode(): @@ -53,8 +53,8 @@ def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) - # TODO: add from_scaled_tc_fpx() - fp16_weight = from_tc_float6_e3m2(fpx_weight.view(torch.uint8), dtype=torch.float16) * scale[:, None] + fp32_weight = _from_tc_fpx(fpx_weight.view(torch.uint8), ebits, mbits) * scale[:, None].float() + fp16_weight = fp32_weight.half() results_fp16 = fp16_act @ fp16_weight.T error = (results_fpx - results_fp16).abs() diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 83d9d12051..e8ec8af942 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -18,7 +18,7 @@ def _unpack(x: Tensor, n_bits: int) -> Tensor: # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 -def _bit_interleave(x: Tensor, n_bits: int) -> Tensor: +def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 # thus, we need to reverse byte order within a uint32 word. x = x.reshape(-1, 4).flip(1) @@ -30,10 +30,13 @@ def _bit_interleave(x: Tensor, n_bits: int) -> Tensor: 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], - 4: [1, 5, 3, 7, 0, 4, 2, 6] + 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] - x = x[:, bit_order] + if undo: + bit_order = [bit_order.index(i) for i in range(len(bit_order))] + + x = x[:, bit_order] x = _pack(x, n_bits) # reverse byte order within a uint32 word again. 
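The new `undo=True` branch works because applying the index-inverse of `bit_order` reverses the gather `x[:, bit_order]`; a small self-contained sketch of that inversion (plain Python lists, independent of the packing helpers):

```python
# mirror of the inversion used when undo=True in _bit_interleave
order = [1, 5, 3, 7, 0, 4, 2, 6]                       # the 4-bit ordering above
inverse = [order.index(i) for i in range(len(order))]  # invert the permutation

data = list(range(8))
shuffled = [data[i] for i in order]       # same semantics as x[:, bit_order]
restored = [shuffled[i] for i in inverse]
assert restored == data                   # interleave followed by undo is the identity
```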
@@ -57,18 +60,18 @@ def _to_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: tensor_fpx = tensor_fpx.permute(1, 0, 2) tensor_fpx = tensor_fpx.flatten() - n_bits = 1 + ebits + mbits + total_bits = 1 + ebits + mbits n_used = 0 fragments = [] for y in [1, 2, 4]: - if n_bits & y: + if total_bits & y: mask = (1 << y) - 1 - tensor_ybit = (tensor_fpx >> (n_bits - n_used - y)) & mask + tensor_ybit = (tensor_fpx >> (total_bits - n_used - y)) & mask tensor_ybit = _pack(tensor_ybit, y) tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code - tensor_ybit = _bit_interleave(tensor_ybit, y) # Pass 3 from original code + tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code fragments.append(tensor_ybit) n_used += y @@ -112,6 +115,44 @@ def to_scaled_tc_float6_e3m2(tensor: Tensor) -> Tuple[Tensor, Tensor]: return tc_fp6_tensor, scale.reciprocal().half() +def _from_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: + total_bits = 1 + ebits + mbits + M = tensor.shape[0] + size = tensor.numel() + tensor = tensor.flatten() + offset = 0 + n_used = 0 + + tensor_fpx = None + + for y in [1, 2, 4]: + if total_bits & y: + size_ybit = size // total_bits * y + tensor_ybit = tensor[offset : offset + size_ybit] + offset += size_ybit + + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + + tensor_ybit = _unpack(tensor_ybit.flatten(), y) + tensor_ybit = tensor_ybit << (total_bits - n_used - y) + n_used += y + + if tensor_fpx is None: + tensor_fpx = tensor_ybit + else: + tensor_fpx |= tensor_ybit + + # undo Pass 1 + tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2) + tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8) + tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6) + tensor_fpx = tensor_fpx.reshape(M, -1) + + tensor_fp32 = _fpx_unpacked_to_f32(tensor_fpx, ebits, mbits) + return tensor_fp32 + + def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> Tensor: assert tensor.ndim == 2 and tensor.dtype == torch.uint8 M = tensor.shape[0] From bb52ad0464fb07154acb6e6dfc96aa55e3d7ad2d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 19:26:42 +0800 Subject: [PATCH 16/31] remove redundant code --- test/test_ops.py | 3 +- torchao/prototype/fp6_llm/fp6_llm.py | 54 +++------------------------- 2 files changed, 6 insertions(+), 51 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 615b6e9796..557296e74f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -53,8 +53,7 @@ def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) - fp32_weight = _from_tc_fpx(fpx_weight.view(torch.uint8), ebits, mbits) * scale[:, None].float() - fp16_weight = fp32_weight.half() + fp16_weight = _from_tc_fpx(fpx_weight.view(torch.uint8), ebits, mbits).half() * scale[:, None] results_fp16 = fp16_act @ fp16_weight.T error = (results_fpx - results_fp16).abs() diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index e8ec8af942..85f709b4c4 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.prototype.mx_formats.constants import F6_E3M2_MAX -from torchao.ops 
import fp6_llm_linear, quant_llm_linear +from torchao.ops import quant_llm_linear def _pack(x: Tensor, n_bits: int) -> Tensor: @@ -315,54 +315,6 @@ def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> T ] -class Fp6LlmLinear(nn.Module): - """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. - """ - - def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None: - super().__init__() - self.register_buffer("weight", weight.view(torch.int32)) - self.register_buffer("scales", scales) - self.register_buffer("bias", bias) - self.out_features = weight.shape[0] - self.in_features = weight.shape[1] // 3 * 4 - - def forward(self, x: Tensor) -> Tensor: - splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features) - out = fp6_llm_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK) - if self.bias is not None: - out = out + self.bias - return out.view(*x.shape[:-1], self.out_features).to(x.dtype) - - @staticmethod - def get_split_k(bsize: int, out_dim: int) -> int: - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - @classmethod - def from_float(cls, linear: nn.Linear): - assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) - - fp6_weight, scale = to_scaled_tc_float6_e3m2(linear.weight.detach()) - bias = linear.bias.detach().half() if linear.bias is not None else None - return cls(fp6_weight, scale, bias) - - def extra_repr(self) -> str: - return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' - - -def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[List[str]] = None, cur_fqn: str = "") -> None: - for name, child in model.named_children(): - new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" - - if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): - if (child.in_features % 64 == 0) and (child.out_features % 256 == 0): - new_child = Fp6LlmLinear.from_float(child) - setattr(model, name, new_child) - else: - convert_fp6_llm(child, skip_fqn_list, new_fqn) - - class QuantLlmLinear(nn.Module): """Quant-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. 
""" @@ -437,3 +389,7 @@ def convert_quant_llm( setattr(model, name, new_child) else: convert_quant_llm(child, ebits, mbits, skip_fqn_list, new_fqn) + + +def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[List[str]] = None, cur_fqn: str = "") -> None: + return convert_quant_llm(model, 3, 2, skip_fqn_list=skip_fqn_list, cur_fqn=cur_fqn) From 80661abd90a36242717899ada0870b65e592c5a7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 20:01:38 +0800 Subject: [PATCH 17/31] fix import --- torchao/prototype/fp6_llm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/fp6_llm/__init__.py b/torchao/prototype/fp6_llm/__init__.py index befdfbec40..5f4a47dd5e 100644 --- a/torchao/prototype/fp6_llm/__init__.py +++ b/torchao/prototype/fp6_llm/__init__.py @@ -1 +1 @@ -from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2, convert_quant_llm +from .fp6_llm import convert_fp6_llm, convert_quant_llm, to_scaled_tc_float6_e3m2 From edfbe3d1ca838ed9badbe974db749b7075ccfb5e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 20:35:21 +0800 Subject: [PATCH 18/31] fix test --- test/prototype/test_fp6_llm.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py index fdb228023e..ce693e0a69 100644 --- a/test/prototype/test_fp6_llm.py +++ b/test/prototype/test_fp6_llm.py @@ -11,8 +11,8 @@ to_tc_float6_e3m2, from_tc_float6_e3m2, _to_tc_fpx, - Fp6LlmLinear, - convert_fp6_llm, + QuantLlmLinear, + convert_quant_llm, ) from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked @@ -20,7 +20,7 @@ _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) -class TestFp6LlmLinear(TestCase): +class TestQuantLlmLinear(TestCase): @parametrize("device", _DEVICES) def test_to_tc_float6_e3m2_correctness(self, device): x = torch.randn(256, 64, device=device) @@ -59,12 +59,13 @@ def test_from_tc_float6_e3m2_compile(self, device): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("leading_dims", [(4,), (2, 4)]) @parametrize("bias", [False, True]) - def test_fp6_llm_linear_forward(self, bias, leading_dims): + def test_quant_llm_linear_forward(self, bias, leading_dims): OC, IC = 256, 64 device = "cuda" + ebits, mbits = 3, 2 linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = Fp6LlmLinear.from_float(linear) + fp6_linear = QuantLlmLinear.from_float(linear, mbits, ebits) assert (fp6_linear.bias is not None) == bias x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half) @@ -72,12 +73,13 @@ def test_fp6_llm_linear_forward(self, bias, leading_dims): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("bias", [False, True]) - def test_fp6_llm_linear_compile(self, bias): + def test_quant_llm_linear_compile(self, bias): N, OC, IC = 4, 256, 64 device = "cuda" + ebits, mbits = 3, 2 linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = Fp6LlmLinear.from_float(linear) + fp6_linear = QuantLlmLinear.from_float(linear, ebits, mbits) x = torch.randn(N, IC, device=device, dtype=torch.half) expected = fp6_linear(x) @@ -85,21 +87,23 @@ def test_fp6_llm_linear_compile(self, bias): torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_convert_fp6_llm(self): + def 
test_convert_quant_llm(self): device = "cuda" + ebits, mbits = 3, 2 + model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device) - convert_fp6_llm(model) + convert_quant_llm(model, ebits, mbits) - assert isinstance(model[0], Fp6LlmLinear) + assert isinstance(model[0], QuantLlmLinear) assert model[0].bias is None - assert isinstance(model[1], Fp6LlmLinear) + assert isinstance(model[1], QuantLlmLinear) assert model[1].bias is not None x = torch.randn(4, 64, device=device) model(x) -instantiate_parametrized_tests(TestFp6LlmLinear) +instantiate_parametrized_tests(TestQuantLlmLinear) if __name__ == "__main__": From 0ecdd860b6353e357ff5437511940c6b523041f9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 19 Jun 2024 22:06:00 +0800 Subject: [PATCH 19/31] avoid division by 0 --- test/test_ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 557296e74f..b4ad18a7e6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -48,7 +48,7 @@ def test_quant_llm_linear(self, ebits, mbits): @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) @@ -56,9 +56,10 @@ def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): fp16_weight = _from_tc_fpx(fpx_weight.view(torch.uint8), ebits, mbits).half() * scale[:, None] results_fp16 = fp16_act @ fp16_weight.T - error = (results_fpx - results_fp16).abs() - relative_error = error / results_fp16.abs() - assert relative_error.mean() < 1e-2 + error = (results_fpx - results_fp16).abs().mean() + gt = results_fp16.abs().mean() + relative_error = error / gt + assert relative_error < 1e-3 instantiate_parametrized_tests(TestOps) From e6c7d6b5614435ebe5e84e848f9731fdd582ce01 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 20:10:18 +0800 Subject: [PATCH 20/31] add subclass. 
use uint8 --- test/test_ops.py | 2 +- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 21 ++--- torchao/ops.py | 8 +- torchao/prototype/fp6_llm/fp6_llm.py | 109 +++++++++++++++++++++++- 4 files changed, 123 insertions(+), 17 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index b4ad18a7e6..58c0c571bb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -23,7 +23,7 @@ class TestOps(TestCase): def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): # Randomly initialize each byte nbits = 1 + ebits + mbits - fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8).view(torch.int32) + fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) scale = torch.rand(OC).half() + 0.5 fp16_act = torch.rand(BS, IC).half() + 0.5 return fpx_weight.to(device), scale.to(device), fp16_act.to(device) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 3f0c976e78..44cd2e39a2 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -119,6 +119,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, #include namespace torchao { +// MODIFICATION NOTE: dtype of _weights is changed to uint8 /* Computes FPx-FP16 GEMM (PyTorch interface). @@ -128,7 +129,7 @@ After Equivalent transformation : trans(Out) = W * trans(In). Note that we [Inputs] _in_feats: tensor of shape [B, IC]; // half - _weights: int tensor of shape [OC, IC // 32 * x]; // x INT32 words contains 32 FPx weights. + _weights: int tensor of shape [OC, IC // 8 * x]; // x UINT8 words contains 8 FPx weights. _scales: tensor of shape [OC]; // half splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. [Outputs] @@ -142,18 +143,18 @@ torch::Tensor fp_eXmY_linear_forward_cuda( torch::Tensor _scales, int64_t splitK=1) { - const int64_t NBITS = 1 + EXPONENT + MANTISSA; + const int64_t NBITS = 1 + EXPONENT + MANTISSA; int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); int num_out_channels = _weights.size(0); TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); - TORCH_CHECK((num_in_channels / 32 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. + TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. // int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; // Input Tensors - auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. + auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. 
auto in_feats = reinterpret_cast(_in_feats.data_ptr()); auto scales = reinterpret_cast(_scales.data_ptr()); // Output Tensors @@ -176,12 +177,12 @@ torch::Tensor fp_eXmY_linear_forward_cuda( fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else if (EXPONENT == 3 && MANTISSA == 1) fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 2 && MANTISSA == 1) - fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 3 && MANTISSA == 0) - fpx_linear_kernel<3, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 2 && MANTISSA == 0) - fpx_linear_kernel<2, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 1) + // fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 3 && MANTISSA == 0) + // fpx_linear_kernel<3, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 0) + // fpx_linear_kernel<2, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); diff --git a/torchao/ops.py b/torchao/ops.py index 127f54192b..1f118709e9 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -18,7 +18,7 @@ def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: Arguments _in_feats: input activations in FP16 - _weights: packed FP6 weights. See :func:prepack_fp6_weight and :func:fp16_to_fp6 + _weights: packed FP6 weights _scales: scale splitK: split K @@ -43,7 +43,7 @@ def quant_llm_linear( EXPONENT: number of exponent bits MANTISSA: number of mantissa bits _in_feats: input activations in FP16 - _weights: packed FP6 weights. 
See :func:prepack_fp6_weight and :func:fp16_to_fp6 + _weights: packed FPx weights _scales: scale splitK: split K @@ -58,14 +58,14 @@ def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1): torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") - torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}") + torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") BS, IC = _in_feats.shape OC, _ = _weights.shape N_BITS = 1 + EXPONENT + MANTISSA - torch._check(IC // 32 * N_BITS == _weights.shape[1], lambda: "Dimensions mismatched") + torch._check(IC // 8 * N_BITS == _weights.shape[1], lambda: "Dimensions mismatched") torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 85f709b4c4..63d7509cf7 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -4,9 +4,11 @@ import torch from torch import nn, Tensor +from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.prototype.mx_formats.constants import F6_E3M2_MAX from torchao.ops import quant_llm_linear +from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE def _pack(x: Tensor, n_bits: int) -> Tensor: @@ -160,7 +162,7 @@ def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> T assert (M % 64 == 0) and (N % 64 == 0) size_2bit = M * N // 4 size_4bit = M * N // 2 - tensor = tensor.view(-1).view(torch.uint8) + tensor = tensor.view(-1) assert tensor.numel() == size_2bit + size_4bit tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) @@ -315,6 +317,109 @@ def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> T ] +class QuantLlmLinearWeight(Tensor): + _implements = classmethod(_implements) + + @staticmethod + def new(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + assert fpx_data.ndim == 2 + assert fpx_data.dtype == torch.uint8 + shape = (fpx_data.shape[0], fpx_data.shape[1] // (1 + ebits + mbits) * 8) + + return Tensor._make_wrapper_subclass( + cls, + shape, + device=fpx_data.device, + requires_grad=False, + ) + + def init(self, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + self.fpx_data = fpx_data + self.scale = scale + self.ebits = ebits + self.mbits = mbits + + def __tensor_flatten__(self): + return ["fpx_data", "scale"], [self.ebits, self.mbits] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride): + return cls(tensor_data_dict["fpx_data"], tensor_data_dict["scale"], *tensor_attributes) + + @classmethod + def from_float(cls, input_float: Tensor, ebits: int, mbits: int): + fpx_data, scale = _to_scaled_tc_fpx(input_float, ebits, mbits) + return cls(fpx_data, scale, ebits, mbits) + + def dequantize(self, output_dtype=None): + output_dtype = output_dtype or 
torch.get_default_dtype() + return _from_tc_fpx(self.fpx_data, self.ebits, self.mbits) * self.scale.view(-1, 1) + + def repr(self): + dtype = f"fp{1 + self.ebits + self.mbits}_e{self.ebits}m{self.mbits}" + return ( + f"{self.__class__.name}(dtype={dtype}, shape={self.shape}, " + f"device={self.device}, requires_grad={self.requires_grad})" + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.fpx_data), + fn(self.scale), + self.ebits, + self.mbits, + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) + + raise NotImplementedError(f"{cls.name} dispatch: attempting to run {func}, this is not supported") + + +@QuantLlmLinearWeight._implements(torch.nn.functional.linear) +def _(*args, **kwargs): + act, weight, bias = args + assert isinstance(weight, QuantLlmLinearWeight) + + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim).half() + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + out = quant_llm_linear( + weight.ebits, + weight.mbits, + act_reshaped, + weight.fpx_data, + weight.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + return out.view(*act.shape[:-1], out_dim).to(act.dtype) + + +@QuantLlmLinearWeight._implements(torch.ops.aten.detach.default) +def _(func, *args, **kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + + class QuantLlmLinear(nn.Module): """Quant-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. 
""" @@ -328,7 +433,7 @@ def __init__( bias: Optional[Tensor] = None, ) -> None: super().__init__() - self.register_buffer("weight", weight.view(torch.int32)) + self.register_buffer("weight", weight) self.register_buffer("scales", scales) self.register_buffer("bias", bias) self.out_features = weight.shape[0] From ca43bf886f37c2a1f76ba19cfacac594d55ae395 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 20:41:54 +0800 Subject: [PATCH 21/31] subclass API --- test/prototype/test_fp6_llm.py | 50 ++++++-------- torchao/prototype/fp6_llm/README.md | 6 +- torchao/prototype/fp6_llm/__init__.py | 2 +- torchao/prototype/fp6_llm/fp6_llm.py | 96 ++++----------------------- 4 files changed, 39 insertions(+), 115 deletions(-) diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py index ce693e0a69..4959d6752f 100644 --- a/test/prototype/test_fp6_llm.py +++ b/test/prototype/test_fp6_llm.py @@ -1,6 +1,7 @@ +import copy + import pytest import torch -from torch import nn from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, @@ -8,16 +9,18 @@ run_tests, ) from torchao.prototype.fp6_llm.fp6_llm import ( + QuantLlmLinearWeight, + quant_llm_fpx_weight_only, to_tc_float6_e3m2, from_tc_float6_e3m2, _to_tc_fpx, - QuantLlmLinear, - convert_quant_llm, ) from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked +from torchao.quantization.quant_api import quantize _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_FPx_DTYPES = [(3, 2), (2, 2)] class TestQuantLlmLinear(TestCase): @@ -57,51 +60,38 @@ def test_from_tc_float6_e3m2_compile(self, device): torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", _FPx_DTYPES) @parametrize("leading_dims", [(4,), (2, 4)]) @parametrize("bias", [False, True]) - def test_quant_llm_linear_forward(self, bias, leading_dims): + def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims): OC, IC = 256, 64 device = "cuda" - ebits, mbits = 3, 2 - linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = QuantLlmLinear.from_float(linear, mbits, ebits) - assert (fp6_linear.bias is not None) == bias + fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half) + fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None + + fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits) x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half) - fp6_linear(x) + out = torch.nn.functional.linear(x, fpx_weight, fp16_bias) + assert out.shape == leading_dims + (OC,) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", _FPx_DTYPES) @parametrize("bias", [False, True]) - def test_quant_llm_linear_compile(self, bias): + def test_quant_llm_quantize(self, ebits, mbits, bias): N, OC, IC = 4, 256, 64 device = "cuda" - ebits, mbits = 3, 2 linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = QuantLlmLinear.from_float(linear, ebits, mbits) + fpx_linear = copy.deepcopy(linear) + quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) x = torch.randn(N, IC, device=device, dtype=torch.half) - expected = fp6_linear(x) - actual = torch.compile(fp6_linear, fullgraph=True)(x) + expected = fpx_linear(x) + actual = torch.compile(fpx_linear, fullgraph=True)(x) torch.testing.assert_close(actual, expected) - @pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") - def test_convert_quant_llm(self): - device = "cuda" - ebits, mbits = 3, 2 - - model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device) - convert_quant_llm(model, ebits, mbits) - - assert isinstance(model[0], QuantLlmLinear) - assert model[0].bias is None - assert isinstance(model[1], QuantLlmLinear) - assert model[1].bias is not None - - x = torch.randn(4, 64, device=device) - model(x) - instantiate_parametrized_tests(TestQuantLlmLinear) diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md index 6797694f0a..b64f46155c 100644 --- a/torchao/prototype/fp6_llm/README.md +++ b/torchao/prototype/fp6_llm/README.md @@ -5,10 +5,11 @@ This is a FP16 x FP6 mixed matmul kernel optimized for io bound workloads per [F ## Usage ```python -from torchao.prototype.fp6_llm import convert_fp6_llm +from torchao.quantization.quant_api import quantize +from torchao.prototype.fp6_llm import fp6_llm_weight_only model = ... -convert_fp6_llm(model) # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear +quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 in-place # fully compatible with torch.compile() model.compile(mode="max-autotune", fullgraph=True) @@ -17,6 +18,7 @@ model.compile(mode="max-autotune", fullgraph=True) It's also possible to pre-process the weight and call the kernel directly. ```python +# TODO: update import torch from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2 from torchao.ops import fp6_llm_linear diff --git a/torchao/prototype/fp6_llm/__init__.py b/torchao/prototype/fp6_llm/__init__.py index 5f4a47dd5e..7ed84c9dcc 100644 --- a/torchao/prototype/fp6_llm/__init__.py +++ b/torchao/prototype/fp6_llm/__init__.py @@ -1 +1 @@ -from .fp6_llm import convert_fp6_llm, convert_quant_llm, to_scaled_tc_float6_e3m2 +from .fp6_llm import QuantLlmLinearWeight, fp6_llm_weight_only, quant_llm_fpx_weight_only diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index 63d7509cf7..e474cbfc55 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -1,9 +1,8 @@ from functools import reduce -import math -from typing import List, Optional, Tuple +from typing import Tuple import torch -from torch import nn, Tensor +from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.prototype.mx_formats.constants import F6_E3M2_MAX @@ -321,7 +320,7 @@ class QuantLlmLinearWeight(Tensor): _implements = classmethod(_implements) @staticmethod - def new(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + def __new__(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): assert fpx_data.ndim == 2 assert fpx_data.dtype == torch.uint8 shape = (fpx_data.shape[0], fpx_data.shape[1] // (1 + ebits + mbits) * 8) @@ -333,7 +332,7 @@ def new(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): requires_grad=False, ) - def init(self, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + def __init__(self, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): self.fpx_data = fpx_data self.scale = scale self.ebits = ebits @@ -355,7 +354,7 @@ def dequantize(self, output_dtype=None): output_dtype = output_dtype or torch.get_default_dtype() return _from_tc_fpx(self.fpx_data, self.ebits, self.mbits) * self.scale.view(-1, 1) 
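A rough round-trip sketch for the new subclass (shapes are hypothetical; only bit manipulation is involved, so this part runs on CPU as well):

```python
import torch
from torchao.prototype.fp6_llm import QuantLlmLinearWeight

w = torch.randn(256, 64)  # both dims are multiples of 64, as required by the packing
fpx_w = QuantLlmLinearWeight.from_float(w, ebits=3, mbits=2)

# dequantize() should approximately reconstruct the original weight
err = (fpx_w.dequantize() - w).abs().mean()
print(err)
```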
- def repr(self): + def __repr__(self): dtype = f"fp{1 + self.ebits + self.mbits}_e{self.ebits}m{self.mbits}" return ( f"{self.__class__.name}(dtype={dtype}, shape={self.shape}, " @@ -420,81 +419,14 @@ def _(func, *args, **kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) -class QuantLlmLinear(nn.Module): - """Quant-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. - """ - - def __init__( - self, - ebits: int, - mbits: int, - weight: Tensor, - scales: Tensor, - bias: Optional[Tensor] = None, - ) -> None: - super().__init__() - self.register_buffer("weight", weight) - self.register_buffer("scales", scales) - self.register_buffer("bias", bias) - self.out_features = weight.shape[0] - self.in_features = weight.shape[1] // (1 + ebits + mbits) * 8 - self.ebits = ebits - self.mbits = mbits - - def forward(self, x: Tensor) -> Tensor: - splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features) - out = quant_llm_linear( - self.ebits, - self.mbits, - x.view(-1, self.in_features).half(), - self.weight, - self.scales, - splitK=splitK, - ) - if self.bias is not None: - out = out + self.bias - return out.view(*x.shape[:-1], self.out_features).to(x.dtype) - - @staticmethod - def get_split_k(bsize: int, out_dim: int) -> int: - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - @classmethod - def from_float(cls, linear: nn.Linear, ebits: int, mbits: int): - assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) - - fpx_weight, scale = _to_scaled_tc_fpx(linear.weight.detach(), ebits, mbits) - bias = linear.bias.detach().half() if linear.bias is not None else None - return cls(ebits, mbits, fpx_weight, scale, bias) - - def extra_repr(self) -> str: - return ( - f'in_features={self.in_features}' - f', out_features={self.out_features}' - f', bias={self.bias is not None}' - f', ebits={self.ebits}' - f', mbits={self.mbits}' - ) - - -def convert_quant_llm( - model: nn.Module, - ebits: int, - mbits: int, - skip_fqn_list: Optional[List[str]] = None, - cur_fqn: str = "", -) -> None: - for name, child in model.named_children(): - new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" - - if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): - if (child.in_features % 64 == 0) and (child.out_features % 256 == 0): - new_child = QuantLlmLinear.from_float(child, ebits, mbits) - setattr(model, name, new_child) - else: - convert_quant_llm(child, ebits, mbits, skip_fqn_list, new_fqn) +def quant_llm_fpx_weight_only(ebits: int, mbits: int): + def apply_quant_llm(weight: Tensor) -> Tensor: + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + return weight + return QuantLlmLinearWeight.from_float(weight, ebits, mbits) + return apply_quant_llm -def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[List[str]] = None, cur_fqn: str = "") -> None: - return convert_quant_llm(model, 3, 2, skip_fqn_list=skip_fqn_list, cur_fqn=cur_fqn) +def fp6_llm_weight_only(): + return quant_llm_fpx_weight_only(3, 2) From 8de272237ba36f9877a10aa9e43ae270d248aa2c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 21:04:42 +0800 Subject: [PATCH 22/31] update doc --- torchao/prototype/fp6_llm/README.md | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 
deletions(-) diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md index b64f46155c..fc20aa626c 100644 --- a/torchao/prototype/fp6_llm/README.md +++ b/torchao/prototype/fp6_llm/README.md @@ -1,15 +1,19 @@ -# FP6-LLM +# Quant-LLM -This is a FP16 x FP6 mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32 weights to FP6 and facility to convert existing models to FP6. +This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to FPx and integration with torchao API. ## Usage ```python from torchao.quantization.quant_api import quantize -from torchao.prototype.fp6_llm import fp6_llm_weight_only +from torchao.prototype.fp6_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only model = ... -quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 in-place +model.half() # not necessary, but recommeneded to maintain accuracy +quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place + +# for generic FPx EyMz where x = 1 + y + z +# quantize(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead # fully compatible with torch.compile() model.compile(mode="max-autotune", fullgraph=True) @@ -18,22 +22,24 @@ model.compile(mode="max-autotune", fullgraph=True) It's also possible to pre-process the weight and call the kernel directly. ```python -# TODO: update import torch -from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2 -from torchao.ops import fp6_llm_linear +from torchao.prototype.fp6_llm.fp6_llm import _to_scaled_tc_fpx +from torchao.ops import quant_llm_linear fp32_weight = torch.randn(1024, 512).cuda() +ebits, mbits = 3, 2 # pre-process the weight. this will quantize the weight to FP6 and pack it in a special # layout for tensor cores. refer to paper for more details. -fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight) +fp6_weight, scales = _to_scaled_tc_fpx(fp32_weight, ebits, mbits) fp16_act = torch.randn(1, 512).cuda().half() -outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024) +outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape (1, 1024) ``` -**NOTE**: since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations. +**NOTE**: +- Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations. +- Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1. ## Benchmark results @@ -44,7 +50,7 @@ FPx quantization is run with `--precision float16`. 
The rest uses the default pr Quantization | wikitext perplexity | tokens/s --------------------|---------------------|---------- INT8 | 12.21 | 87.45 -INT4-256 (tinygemm) | 76266957.87 (bug) | 157.10 +INT4-256 (tinygemm) | -- | 157.10 FP6 E3M2 | 12.34 | 106.76 FP6 E2M3 | 12.23 | 106.77 FP5 E3M1 | 12.55 | 122.69 From 50bfe822cdb61b31b57fa7291c198011ec4a6519 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 21:05:25 +0800 Subject: [PATCH 23/31] remove unused op --- torchao/ops.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/torchao/ops.py b/torchao/ops.py index 1f118709e9..3145812a2f 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -12,22 +12,6 @@ def decorator(func): return decorator -def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: - """ - FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. - - Arguments - _in_feats: input activations in FP16 - _weights: packed FP6 weights - _scales: scale - splitK: split K - - Returns - output of linear layer - """ - return quant_llm_linear(3, 2, _in_feats, _weights, _scales, splitK) - - def quant_llm_linear( EXPONENT: int, MANTISSA: int, From b9375a4f0b93b9d25901ceeb5b2a42175587938e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 22:06:08 +0800 Subject: [PATCH 24/31] update --- test/prototype/test_fp6_llm.py | 54 +++++++----- torchao/prototype/custom_fp_utils.py | 4 +- torchao/prototype/fp6_llm/fp6_llm.py | 125 ++++++++++++++++----------- 3 files changed, 106 insertions(+), 77 deletions(-) diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py index 4959d6752f..b17591960a 100644 --- a/test/prototype/test_fp6_llm.py +++ b/test/prototype/test_fp6_llm.py @@ -8,14 +8,14 @@ parametrize, run_tests, ) +from torchao.prototype.fp6_llm import QuantLlmLinearWeight, quant_llm_fpx_weight_only from torchao.prototype.fp6_llm.fp6_llm import ( - QuantLlmLinearWeight, - quant_llm_fpx_weight_only, - to_tc_float6_e3m2, - from_tc_float6_e3m2, - _to_tc_fpx, + _pack_tc_fpx, + _pack_tc_fp6, + to_scaled_tc_fpx, + from_scaled_tc_fpx, ) -from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 from torchao.quantization.quant_api import quantize @@ -23,40 +23,46 @@ _FPx_DTYPES = [(3, 2), (2, 2)] -class TestQuantLlmLinear(TestCase): +class TestQuantLlmLinearWeight(TestCase): @parametrize("device", _DEVICES) - def test_to_tc_float6_e3m2_correctness(self, device): - x = torch.randn(256, 64, device=device) + def test_pack_tc_fp6_correctness(self, device): + x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) - expected = _to_tc_fpx(x, 3, 2) - actual = to_tc_float6_e3m2(x) + expected = _pack_tc_fpx(x, 6) + actual = _pack_tc_fp6(x) torch.testing.assert_close(actual, expected) + @parametrize("ebits,mbits", _FPx_DTYPES) @parametrize("device", _DEVICES) - def test_to_tc_float6_e3m2_compile(self, device): + def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device): x = torch.randn(256, 64, device=device) - expected = to_tc_float6_e3m2(x) - actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x) + expected = to_scaled_tc_fpx(x, ebits, mbits) + actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits) torch.testing.assert_close(actual, expected) + @parametrize("ebits,mbits", _FPx_DTYPES) @parametrize("device", _DEVICES) - def 
test_from_tc_float6_e3m2_correctness(self, device): - x = torch.randn(256, 64, device=device) + def test_from_tc_fpx_correctness(self, ebits, mbits, device): + x = torch.randn(256, 64, device=device) * 100 - # quantize and dequantize so that the values are exactly representable in FP6 - x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x)) + # quantize and dequantize so that the values are exactly representable in FPx + x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits) - actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x)) + tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits) + actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale) torch.testing.assert_close(actual, x) + @parametrize("ebits,mbits", _FPx_DTYPES) @parametrize("device", _DEVICES) - def test_from_tc_float6_e3m2_compile(self, device): + def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device): M, N = 256, 64 - x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device) + nbits = 1 + ebits + mbits + x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device) + scale = torch.randn(M, device=device) - expected = from_tc_float6_e3m2(x) - actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x) + expected = from_scaled_tc_fpx(x, ebits, mbits, scale) + actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale) torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -93,7 +99,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias): torch.testing.assert_close(actual, expected) -instantiate_parametrized_tests(TestQuantLlmLinear) +instantiate_parametrized_tests(TestQuantLlmLinearWeight) if __name__ == "__main__": diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index 1a3e9e34cb..3af11f1710 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -216,7 +216,9 @@ def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 # we can update this in-place since the values won't overlap - mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32 + # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int' + # thus we use + instead of | here + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32 result = torch.where(denormal_mask, mantissa_lp_int32, result) diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py index e474cbfc55..c336680b73 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/fp6_llm/fp6_llm.py @@ -5,11 +5,13 @@ from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones -from torchao.prototype.mx_formats.constants import F6_E3M2_MAX from torchao.ops import quant_llm_linear from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE +_ONES_TABLE = [_n_ones(i) for i in range(8)] + + def _pack(x: Tensor, n_bits: int) -> Tensor: return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) @@ -47,88 +49,93 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing # 
https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h -def _to_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: - assert tensor.ndim == 2 +def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fpx = _f32_to_fpx_unpacked(tensor.float(), ebits, mbits) - # Pass 1 from original code - tensor_fpx = tensor_fpx.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fpx = tensor_fpx.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fpx = tensor_fpx.reshape(-1, 32, 2) - tensor_fpx = tensor_fpx.permute(1, 0, 2) - tensor_fpx = tensor_fpx.flatten() - - total_bits = 1 + ebits + mbits - n_used = 0 + tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) + tensor = tensor.reshape(-1, 32, 2) + tensor = tensor.permute(1, 0, 2) + tensor = tensor.flatten() + + used_bits = 0 fragments = [] for y in [1, 2, 4]: - if total_bits & y: + if nbits & y: mask = (1 << y) - 1 - tensor_ybit = (tensor_fpx >> (total_bits - n_used - y)) & mask + tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code fragments.append(tensor_ybit) - n_used += y + used_bits += y return torch.cat(fragments, dim=0).view(M, -1) -def _to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: - exp_bias = _n_ones(ebits - 1) - max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) - - scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - tc_fpx_tensor = _to_tc_fpx(tensor / scale.view(-1, 1), ebits, mbits) - return tc_fpx_tensor, scale.half() - - -# more optimized version of _to_tc_float6_e3m2_original() by merging ops -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 +# more optimized version of _pack_tc_fpx() for FP6 by merging ops +def _pack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = _f32_to_fpx_unpacked(tensor.float(), 3, 2) - tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.flip(3) + tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) + tensor = tensor.flip(3) - tensor_2bit = (tensor_fp6 >> 4) & 0b11 + tensor_2bit = (tensor >> 4) & 0b11 tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) tensor_2bit = _pack(tensor_2bit.flatten(), 2) - tensor_4bit = tensor_fp6 & 0b1111 + tensor_4bit = tensor & 0b1111 tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) tensor_4bit = _pack(tensor_4bit.flatten(), 4) return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) -def to_scaled_tc_float6_e3m2(tensor: Tensor) -> Tuple[Tensor, Tensor]: - scale = F6_E3M2_MAX / tensor.abs().amax(1).clamp(min=1e-12) - tc_fp6_tensor = to_tc_float6_e3m2(tensor * scale.view(-1, 1)) - return tc_fp6_tensor, scale.reciprocal().half() +# currently only optimize for TC-FP6 packing +def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _pack_tc_fp6(tensor) + return _pack_tc_fpx(tensor, nbits) + +def to_scaled_tc_fpx(tensor: 
Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: + # _n_ones() is not compatible with torch.compile() due to << operator + # https://github.com/pytorch/pytorch/issues/119152 + # exp_bias = _n_ones(ebits - 1) + # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) -def _from_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: - total_bits = 1 + ebits + mbits + # workaround: global lookup table + exp_bias = _ONES_TABLE[ebits - 1] + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + + tensor = tensor.float() + scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal + tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits) + return tensor_tc_fpx, scale.half() + + +# inverse of _pack_tc_fpx() +def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 M = tensor.shape[0] size = tensor.numel() tensor = tensor.flatten() offset = 0 - n_used = 0 + used_bits = 0 tensor_fpx = None for y in [1, 2, 4]: - if total_bits & y: - size_ybit = size // total_bits * y + if nbits & y: + size_ybit = size // nbits * y tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit @@ -136,8 +143,8 @@ def _from_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) - tensor_ybit = tensor_ybit << (total_bits - n_used - y) - n_used += y + tensor_ybit = tensor_ybit << (nbits - used_bits - y) + used_bits += y if tensor_fpx is None: tensor_fpx = tensor_ybit @@ -149,12 +156,12 @@ def _from_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tensor: tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8) tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6) tensor_fpx = tensor_fpx.reshape(M, -1) + return tensor_fpx - tensor_fp32 = _fpx_unpacked_to_f32(tensor_fpx, ebits, mbits) - return tensor_fp32 - -def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> Tensor: +# more optimized version of _unpack_tc_fpx() for FP6 by merging ops +# inverse of _unpack_tc_fp6() +def _unpack_tc_fp6(tensor: Tensor) -> Tensor: assert tensor.ndim == 2 and tensor.dtype == torch.uint8 M = tensor.shape[0] N = tensor.shape[1] // 3 * 4 @@ -176,7 +183,21 @@ def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> T tensor_fp6 = (tensor_2bit << 4) | tensor_4bit tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) - return _fpx_unpacked_to_f32(tensor_fp6, 3, 2).to(dtype) + return tensor_fp6 + + +def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _unpack_tc_fp6(tensor) + return _unpack_tc_fpx(tensor, nbits) + + +def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: + fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits) + tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits) + if scale is not None: + tensor = tensor * scale.float().view(-1, 1) + return tensor # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -347,12 +368,12 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, o @classmethod def from_float(cls, input_float: Tensor, ebits: int, mbits: int): - fpx_data, scale = _to_scaled_tc_fpx(input_float, ebits, mbits) + fpx_data, scale = 
to_scaled_tc_fpx(input_float, ebits, mbits) return cls(fpx_data, scale, ebits, mbits) def dequantize(self, output_dtype=None): output_dtype = output_dtype or torch.get_default_dtype() - return _from_tc_fpx(self.fpx_data, self.ebits, self.mbits) * self.scale.view(-1, 1) + return from_scaled_tc_fpx(self.fpx_data, self.ebits, self.mbits, self.scale).to(output_dtype) def __repr__(self): dtype = f"fp{1 + self.ebits + self.mbits}_e{self.ebits}m{self.mbits}" From 3072257bbeed2044e35683b506a5744f19dc31f4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 22:25:54 +0800 Subject: [PATCH 25/31] rename. update --- benchmarks/benchmark_fp6_llm.py | 29 +++++++++---------- .../{test_fp6_llm.py => test_quant_llm.py} | 8 ++--- test/test_ops.py | 4 +-- torchao/prototype/fp6_llm/__init__.py | 1 - .../{fp6_llm => quant_llm}/README.md | 6 ++-- torchao/prototype/quant_llm/__init__.py | 1 + .../fp6_llm.py => quant_llm/quant_llm.py} | 4 ++- 7 files changed, 27 insertions(+), 26 deletions(-) rename test/prototype/{test_fp6_llm.py => test_quant_llm.py} (95%) delete mode 100644 torchao/prototype/fp6_llm/__init__.py rename torchao/prototype/{fp6_llm => quant_llm}/README.md (93%) create mode 100644 torchao/prototype/quant_llm/__init__.py rename torchao/prototype/{fp6_llm/fp6_llm.py => quant_llm/quant_llm.py} (99%) diff --git a/benchmarks/benchmark_fp6_llm.py b/benchmarks/benchmark_fp6_llm.py index ae17764e68..b6b99c0ebe 100644 --- a/benchmarks/benchmark_fp6_llm.py +++ b/benchmarks/benchmark_fp6_llm.py @@ -1,25 +1,24 @@ import torch -from torch import nn -from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2 -from torch.utils.benchmark import Timer import pandas as pd +import torch.nn.functional as F +from torchao.prototype.quant_llm import QuantLlmLinearWeight +from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm def benchmark(m: int, k: int, n: int): - fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda") - scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5 - fp6_linear = Fp6LlmLinear(fp6_weight, scales) + fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda") + scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5 + fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2) - fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda") - fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None] + fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") - fp6_output = fp6_linear(fp16_act) - fp16_output = fp16_linear(fp16_act) + fp6_output = F.linear(fp16_act, fp6_weight) + fp16_output = F.linear(fp16_act, fp16_weight) - fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange() - fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange() + fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight) + fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight) # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py # doesn't seem to be the right way to check for correctness @@ -29,9 +28,9 @@ def benchmark(m: int, k: int, n: int): "m": m, "k": k, "n": n, - "fp6_latency (ms)": fp6_measurement.median * 1000, - "fp16_latency (ms)": fp16_measurement.median * 1000, - "speedup (d/s)": 
fp16_measurement.median / fp6_measurement.median, + "fp6_latency (ms)": fp6_time, + "fp16_latency (ms)": fp16_time, + "speedup (d/s)": fp16_time / fp6_time, "correct": correct, } diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_quant_llm.py similarity index 95% rename from test/prototype/test_fp6_llm.py rename to test/prototype/test_quant_llm.py index b17591960a..77eac6f69d 100644 --- a/test/prototype/test_fp6_llm.py +++ b/test/prototype/test_quant_llm.py @@ -8,13 +8,13 @@ parametrize, run_tests, ) -from torchao.prototype.fp6_llm import QuantLlmLinearWeight, quant_llm_fpx_weight_only -from torchao.prototype.fp6_llm.fp6_llm import ( - _pack_tc_fpx, - _pack_tc_fp6, +from torchao.prototype.quant_llm import ( + QuantLlmLinearWeight, + quant_llm_fpx_weight_only, to_scaled_tc_fpx, from_scaled_tc_fpx, ) +from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 from torchao.quantization.quant_api import quantize diff --git a/test/test_ops.py b/test/test_ops.py index 58c0c571bb..57f4d332d0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,7 +7,7 @@ ) from torch.testing._internal.optests import opcheck from torchao.utils import is_fbcode -from torchao.prototype.fp6_llm.fp6_llm import _from_tc_fpx +from torchao.prototype.quant_llm import from_scaled_tc_fpx import pytest if is_fbcode(): @@ -53,7 +53,7 @@ def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) - fp16_weight = _from_tc_fpx(fpx_weight.view(torch.uint8), ebits, mbits).half() * scale[:, None] + fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half() results_fp16 = fp16_act @ fp16_weight.T error = (results_fpx - results_fp16).abs().mean() diff --git a/torchao/prototype/fp6_llm/__init__.py b/torchao/prototype/fp6_llm/__init__.py deleted file mode 100644 index 7ed84c9dcc..0000000000 --- a/torchao/prototype/fp6_llm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fp6_llm import QuantLlmLinearWeight, fp6_llm_weight_only, quant_llm_fpx_weight_only diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/quant_llm/README.md similarity index 93% rename from torchao/prototype/fp6_llm/README.md rename to torchao/prototype/quant_llm/README.md index fc20aa626c..f858ef86eb 100644 --- a/torchao/prototype/fp6_llm/README.md +++ b/torchao/prototype/quant_llm/README.md @@ -6,7 +6,7 @@ This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [F ```python from torchao.quantization.quant_api import quantize -from torchao.prototype.fp6_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only +from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only model = ... model.half() # not necessary, but recommeneded to maintain accuracy @@ -23,7 +23,7 @@ It's also possible to pre-process the weight and call the kernel directly. ```python import torch -from torchao.prototype.fp6_llm.fp6_llm import _to_scaled_tc_fpx +from torchao.prototype.quant_llm import to_scaled_tc_fpx from torchao.ops import quant_llm_linear fp32_weight = torch.randn(1024, 512).cuda() @@ -31,7 +31,7 @@ ebits, mbits = 3, 2 # pre-process the weight. this will quantize the weight to FP6 and pack it in a special # layout for tensor cores. refer to paper for more details. 
-fp6_weight, scales = _to_scaled_tc_fpx(fp32_weight, ebits, mbits) +fp6_weight, scales = to_scaled_tc_fpx(fp32_weight, ebits, mbits) fp16_act = torch.randn(1, 512).cuda().half() outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape (1, 1024) diff --git a/torchao/prototype/quant_llm/__init__.py b/torchao/prototype/quant_llm/__init__.py new file mode 100644 index 0000000000..4f1479c401 --- /dev/null +++ b/torchao/prototype/quant_llm/__init__.py @@ -0,0 +1 @@ +from .quant_llm import QuantLlmLinearWeight, fp6_llm_weight_only, quant_llm_fpx_weight_only, to_scaled_tc_fpx, from_scaled_tc_fpx diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/quant_llm/quant_llm.py similarity index 99% rename from torchao/prototype/fp6_llm/fp6_llm.py rename to torchao/prototype/quant_llm/quant_llm.py index c336680b73..40c3f23232 100644 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -410,7 +410,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): @QuantLlmLinearWeight._implements(torch.nn.functional.linear) def _(*args, **kwargs): - act, weight, bias = args + act = args[0] + weight = args[1] + bias = args[2] if len(args) >= 3 else None assert isinstance(weight, QuantLlmLinearWeight) out_dim, in_dim = weight.shape From 7b822ef1dd6fddf7d8f3c338f51374b6b7fad782 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 22:32:34 +0800 Subject: [PATCH 26/31] update docs --- README.md | 4 ++-- torchao/prototype/README.md | 2 +- torchao/prototype/quant_llm/README.md | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 736915463f..90fb105599 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear}) * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet. * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. 
Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701) -* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)` +* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())` ## Composability @@ -104,7 +104,7 @@ python setup.py install * [GaLore](torchao/prototype/galore/) a drop for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch * [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics * [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512 -* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm) +* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm) * [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common) * [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md index 633099368a..65968ad3e5 100644 --- a/torchao/prototype/README.md +++ b/torchao/prototype/README.md @@ -9,7 +9,7 @@ - `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507) - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm - `galore/docs` - implementation notes and discussion of issues faced in kernel design. -- [`fp6_llm`](fp6_llm) - FP16 x FP6 mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) +- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) #### Roadmap diff --git a/torchao/prototype/quant_llm/README.md b/torchao/prototype/quant_llm/README.md index f858ef86eb..631df30817 100644 --- a/torchao/prototype/quant_llm/README.md +++ b/torchao/prototype/quant_llm/README.md @@ -40,8 +40,9 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape **NOTE**: - Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations. - Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1. +- On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. See https://github.com/pytorch/ao/pull/223 for some microbenchmark results. -## Benchmark results +## End-to-End benchmarks Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py) which generates text in a latency optimized way (batchsize=1). 
wikitext perplexity is measured using [eval.py](../../_models/llama/eval.py) which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). From ca45dda520c3e6d84eedf3c0905d43fa703df6fd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 22:57:30 +0800 Subject: [PATCH 27/31] rename --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 57f4d332d0..28e7437b66 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -47,7 +47,7 @@ def test_quant_llm_linear(self, ebits, mbits): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - def test_fp6_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): + def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") From 36fe61e8157bd1485242343aaf3d6499e56bf1b3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 23:07:05 +0800 Subject: [PATCH 28/31] fix for PyTorch 2.2 --- torchao/prototype/quant_llm/quant_llm.py | 28 +++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py index 40c3f23232..fe5e4d7958 100644 --- a/torchao/prototype/quant_llm/quant_llm.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -29,15 +29,23 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: x = _unpack(x, n_bits) x = x.view(-1, 4 * (8 // n_bits)) - bit_order = { - 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, - 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], - 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], - 4: [1, 5, 3, 7, 0, 4, 2, 6], - }[n_bits] - - if undo: - bit_order = [bit_order.index(i) for i in range(len(bit_order))] + if not undo: + bit_order = { + 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, + 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], + 4: [1, 5, 3, 7, 0, 4, 2, 6], + }[n_bits] + + else: + # this is inverse of the above, obtained by running + # [v.index(i) for i in range(len(v))] + bit_order = { + 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, + 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], + 4: [4, 0, 6, 2, 5, 1, 7, 3], + }[n_bits] x = x[:, bit_order] x = _pack(x, n_bits) @@ -363,7 +371,7 @@ def __tensor_flatten__(self): return ["fpx_data", "scale"], [self.ebits, self.mbits] @classmethod - def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride): + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): return cls(tensor_data_dict["fpx_data"], tensor_data_dict["scale"], *tensor_attributes) @classmethod From 57ad040b7e3605f6d42cf444d90d206c3d6a3305 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jun 2024 22:02:44 +0000 Subject: [PATCH 29/31] _implements -> implements --- 
torchao/prototype/quant_llm/quant_llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py index fe5e4d7958..eaa5c596f5 100644 --- a/torchao/prototype/quant_llm/quant_llm.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -346,7 +346,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te class QuantLlmLinearWeight(Tensor): - _implements = classmethod(_implements) + implements = classmethod(_implements) @staticmethod def __new__(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): @@ -416,7 +416,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise NotImplementedError(f"{cls.name} dispatch: attempting to run {func}, this is not supported") -@QuantLlmLinearWeight._implements(torch.nn.functional.linear) +@QuantLlmLinearWeight.implements(torch.nn.functional.linear) def _(*args, **kwargs): act = args[0] weight = args[1] @@ -445,7 +445,7 @@ def _(*args, **kwargs): return out.view(*act.shape[:-1], out_dim).to(act.dtype) -@QuantLlmLinearWeight._implements(torch.ops.aten.detach.default) +@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default) def _(func, *args, **kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) From ceaa71c3fd0c53e76cc816b1c1492a110821184e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 25 Jun 2024 20:15:49 +0800 Subject: [PATCH 30/31] set CUDA context --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 44cd2e39a2..1d44acde08 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -116,6 +116,7 @@ cudaError_t fpx_linear_kernel(cudaStream_t stream, #include #include +#include #include namespace torchao { @@ -166,23 +167,27 @@ torch::Tensor fp_eXmY_linear_forward_cuda( at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default stream (0) + // this fixes problem with CUDA graphs when used with torch.compile() + auto stream = at::cuda::getCurrentCUDAStream(); + // officially supported in Quant-LLM if (EXPONENT == 3 && MANTISSA == 2) - fpx_linear_kernel<3, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else if (EXPONENT == 2 && MANTISSA == 2) - fpx_linear_kernel<2, 2>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); // experimental else if (EXPONENT == 2 && MANTISSA == 3) - fpx_linear_kernel<2, 3>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else if (EXPONENT == 3 && MANTISSA == 1) - fpx_linear_kernel<3, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, 
Reduction_Workspace, splitK); // else if (EXPONENT == 2 && MANTISSA == 1) - // fpx_linear_kernel<2, 1>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); // else if (EXPONENT == 3 && MANTISSA == 0) - // fpx_linear_kernel<3, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); // else if (EXPONENT == 2 && MANTISSA == 0) - // fpx_linear_kernel<2, 0>(0, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); else TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); From 4e585e9cc78c02718178375b83eb753a4b8c13bb Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 25 Jun 2024 20:15:56 +0800 Subject: [PATCH 31/31] fix __repr__ --- torchao/prototype/quant_llm/quant_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py index eaa5c596f5..8e4fae465d 100644 --- a/torchao/prototype/quant_llm/quant_llm.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -386,7 +386,7 @@ def dequantize(self, output_dtype=None): def __repr__(self): dtype = f"fp{1 + self.ebits + self.mbits}_e{self.ebits}m{self.mbits}" return ( - f"{self.__class__.name}(dtype={dtype}, shape={self.shape}, " + f"{self.__class__.__name__}(dtype={dtype}, shape={self.shape}, " f"device={self.device}, requires_grad={self.requires_grad})" )
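
Taken together, these patches replace the module-swap workflow (`convert_fp6_llm` / `QuantLlmLinear`) with a tensor-subclass one (`QuantLlmLinearWeight` plus `quantize()`-compatible factories) and export the packing helpers as `to_scaled_tc_fpx` / `from_scaled_tc_fpx`. Below is a minimal CPU-only round-trip sketch mirroring `test_from_tc_fpx_correctness` from the updated tests; the private `_f32_to_fpx_unpacked` / `_fpx_unpacked_to_f32` helpers are used purely to make the comparison exact and are not part of the public API.

```python
import torch
from torchao.prototype.quant_llm import to_scaled_tc_fpx, from_scaled_tc_fpx
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

ebits, mbits = 3, 2              # FP6 E3M2; (2, 2) would select FP5 E2M2
x = torch.randn(256, 64) * 100   # both dims must be multiples of 64 for the tensor-core layout

# snap the values onto the FP6 grid first so the pack/unpack round trip is exact
x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)

# pack into the scaled tensor-core layout, then unpack and rescale back to FP32
tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)   # packed uint8 data + per-row FP16 scales
x_recon = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale)

torch.testing.assert_close(x_recon, x)
```

The same pair of helpers backs `QuantLlmLinearWeight.from_float()` and `.dequantize()`, so a weight converted via `quantize(model, fp6_llm_weight_only())` goes through exactly this pack/unpack path before being fed to `quant_llm_linear`.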