diff --git a/CMakeLists.txt b/CMakeLists.txt index be0d3555f..0213ffc2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ # For GCC: `cmake -B build . && cmake --build build` # For MSVC: `cmake -B build . && cmake --build build --config Release` # You can also use the following options and variables -# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `mps`, or `sycl` to select the backend # - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # is whatever CMake finds on your path. @@ -11,7 +11,7 @@ # Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90` # Check your compute capability here: https://developer.nvidia.com/cuda-gpus # - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler -cmake_minimum_required(VERSION 3.22.1) +cmake_minimum_required(VERSION 3.20.4) project(bitsandbytes LANGUAGES CXX) @@ -24,15 +24,18 @@ if(NOT CMAKE_BUILD_TYPE) endif() # Define included source files -set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) +set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(SYCL_FILES csrc/sycl/kernels.cpp csrc/sycl/ops.cpp csrc/pythonInterface.cpp) +#set(SYCL_FILES csrc/sycl/kernel_gemm.cpp csrc/sycl/op_gemm.cpp csrc/sycl/kernel_quant.cpp csrc/sycl/op_quant.cpp) + # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps, sycl)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps sycl) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -50,6 +53,7 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") option(NO_CUBLASLT "Disable CUBLAS" OFF) set(BUILD_CUDA ON) set(BUILD_MPS OFF) + set(BUILD_SYCL OFF) message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) @@ -57,9 +61,15 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps") endif() set(BUILD_CUDA OFF) set(BUILD_MPS ON) + set(BUILD_SYCL OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "sycl") + set(BUILD_CUDA OFF) + set(BUILD_SYCL ON) + set(BUILD_MPS OFF) else() set(BUILD_CUDA OFF) set(BUILD_MPS OFF) + set(BUILD_SYCL OFF) endif() @@ -177,12 +187,31 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_SYCL) + if ( NOT DEFINED ENV{ONEAPI_ROOT}) + message(FATAL_ERROR "Not detect ENV {ONEAPI_ROOT}, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh") + endif() + find_package(IntelSYCL REQUIRED) + set(CMAKE_CXX_STANDARD 17) + add_compile_options(-I./) #include DPCT + add_compile_options(-I/${SYCL_INCLUDE_DIR}) + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") + if (SYCL_TARGET STREQUAL "INTEL") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") + elseif( SYCL_TARGET STREQUAL "NVIDIA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") + list(APPEND SRC_FILES ${SYCL_FILES}) + else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) endif() - - + if(WIN32) # Export all symbols set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) @@ -195,9 +224,12 @@ endif() set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) add_library(bitsandbytes SHARED ${SRC_FILES}) -target_compile_features(bitsandbytes PUBLIC cxx_std_14) -target_include_directories(bitsandbytes PUBLIC csrc include) - +if(BUILD_SYCL) + target_compile_features(bitsandbytes PUBLIC cxx_std_17) +else() + target_compile_features(bitsandbytes PUBLIC cxx_std_14) +endif() +target_include_directories(bitsandbytes PUBLIC csrc csrc/sycl include) if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) @@ -218,6 +250,13 @@ if(BUILD_MPS) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_SYCL) + if (SYCL_TARGET STREQUAL "INTEL") + target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_intel_ilp64 mkl_tbb_thread dnnl) + elseif(SYCL_TARGET STREQUAL "NVIDIA") + target_link_libraries(bitsandbytes PUBLIC onemkl pthread m dl) + endif() +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") endif() diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp new file mode 100644 index 000000000..c2a4559d1 --- /dev/null +++ b/csrc/sycl/kernels.cpp @@ -0,0 +1,5015 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include +#include +#include +#include +#include +#include "kernels.h" +#include +//#include +#include + +#define FLT_MAX std::numeric_limits::max() +#define FLT_MIN std::numeric_limits::min() +#include "utilities.h" + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +#ifdef BLOCK_SIZE +#undef BLOCK_SIZE +#endif + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +float atomicMax(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = dpct::atomic_compare_exchange_strong(reinterpret_cast(address), assumed, sycl::bit_cast(sycl::fmax(val, sycl::bit_cast(assumed)))); + } while (assumed != old); + return sycl::bit_cast(old); +} + +float atomicMin(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = dpct::atomic_compare_exchange_strong(reinterpret_cast(address), assumed, sycl::bit_cast(sycl::fmin(val, sycl::bit_cast(assumed)))); + } while (assumed != old); + return sycl::bit_cast(old); +} + +float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + +float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = sycl::fabs(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +sycl::half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = sycl::fabs(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = sycl::fabs(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : 0;//smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +template +__dpct_inline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) +{ + int lower_pivot = QUADRANT*16-1 - 0; + int pivot = QUADRANT*16-1 + 16; + int upper_pivot = QUADRANT*16-1 + 31; + + float val = midpoint; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 16; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} +//=====================================================NON GEMMS================================ + +//=====================================histogram 2d==================== +SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_dacc_float &dacc_histogram, const sycl_dacc &dacc_index1, + const sycl_dacc &dacc_index2, const sycl_dacc_float &dacc_src) +{ + const int tid = item_ct1.get_local_id(2) + (item_ct1.get_local_range(2)*item_ct1.get_group(2)); + const int numThreads = item_ct1.get_local_range(2)*item_ct1.get_group_range(2); + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (dacc_index1[i]*maxidx1) + dacc_index2[i]; + dpct::atomic_fetch_add(&dacc_histogram[idx], dacc_src[i]); + } +} + + + +//===========================k compress max========================== + +template +void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n, + const sycl::nd_item<3> &item_ct1, int *smem_max_indices, + float *smem_max_values) +{ + + + const int warp_idx = item_ct1.get_local_id(2)/32; + const int valid_items = n - (item_ct1.get_group(2)*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (item_ct1.get_group(2)*BLOCK_SIZE); + + // BLOCK_SIZE/32 == number of warps + T values[8]; + T max1 = -64000.0f; + T max2 = -64000.0f; + int max_idx1 = -1; + int max_idx2 = -1; + int sign1 = -1; + int sign2 = -1; + + sycl::buffer buff_indices(smem_max_indices, sycl::range<1>(8*BLOCK_SIZE/32)); + sycl::buffer buff_values(smem_max_values, sycl::range<1>(8*BLOCK_SIZE/32)); + sycl::buffer buff_A(A,sycl::range<1>(8*BLOCK_SIZE/32)); + + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load<8, dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, int, int *, sycl::nd_item<3>>; + size_t temp_storage_size = group_load::get_local_memory_size(8*BLOCK_SIZE/32); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A[(item_ct1.get_local_id(2)*BLOCK_SIZE)], cgh, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item) { + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_A, values); + + }); + + }); + + #pragma unroll 8 + for(int i = 0; i < 8; i++) + { + T absval = fabsf(values[i]); + if(absval > max1) + { + max1 = values[i]; + sign1 = signbit(values[i]); + max_idx1 = 8*item_ct1.get_local_id(2) + i; + } + else if(absval > max2) + { + max2 = values[i]; + sign2 = signbit(values[i]); + max_idx2 = 8*item_ct1.get_local_id(2) + i; + } + } + + float warp_max; + sycl::host_accessor hacc_values{buff_values}; + sycl::host_accessor hacc_indices{buff_indices}; + for(int i = 0; i < 8; i++) + { + // 3. do warp reduction + broadcast back + + auto output = sycl::reduce_over_group(item_ct1.get_sub_group(), max1, sycl::maximum<>()); + warp_max = item_ct1.get_sub_group().shuffle(warp_max, 0); + + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + if(warp_max == max1) + { + + hacc_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1; + hacc_indices[warp_idx*8 + i] = max_idx1; + + sign1 = sign2; + max1 = max2; + max_idx1 = max_idx2; + + max2 = -64000.0f; + } + sycl::group_barrier(item_ct1.get_sub_group()); + } + + if(item_ct1.get_local_id(2) % 32 < 8) + { + // offset: 8 values per 256 input values + // + int offset = BLOCK_SIZE*item_ct1.get_group(2)*BLOCK_SIZE/32*8; + } + +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + + +//================typedefs=================================== + +typedef sycl::local_accessor sycl_la; +typedef sycl::accessor sycl_dacc; +typedef sycl::accessor sycl_dacc_float; +typedef sycl::accessor sycl_dacc_uc; +typedef sycl::accessor sycl_dacc_char; + +//======================estimte quantiles===================================== + +template +SYCL_EXTERNAL +void kEstimateQuantiles(T*__restrict__ const A, float *code, const float offset, const T max_val, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (item_ct1.get_group(2) * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + using group_load = dpct::group::workgroup_load>; + //using group_radix_sort = dpct::group::radix_sort; + + T vals[NUM_ESTIMATE]; + auto *d_A = dacc_A.template get_multi_ptr().get(); + + + int smem_qidx[BLOCK_ESTIMATE]; + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_A, vals); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + + item_ct1.barrier(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + // bypass sorting + //group_radix_sort(tmp).sort_blocked_to_striped(item_ct1, vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + for(int j = item_ct1.get_local_id(2); j < BLOCK_ESTIMATE; j+=item_ct1.get_local_range(2)) + smem_qidx[j] = -1; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + if(item_ct1.get_local_id(2) < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + + int local_idx = sycl::round(((offset+(item_ct1.get_local_id(2)*q_interval))*(valid_items-1))); + smem_qidx[local_idx] = item_ct1.get_local_id(2); + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for(int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE; i+=item_ct1.get_local_range(2)) + { + if(smem_qidx[i] != -1) + dpct::atomic_fetch_add(&dacc_code[smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + + } + +} + +//====================================k quantize=========================================== +SYCL_EXTERNAL +void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, + const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, + const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (item_ct1.get_group(2) * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + using group_load_float = dpct::group::workgroup_load>; + using group_store_uc = dpct::group::workgroup_store>; + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *d_out = dacc_out.get_multi_ptr().get(); + + if(item_ct1.get_local_id(2) < 256) + { + smem_code[item_ct1.get_local_id(2)] = dacc_code[item_ct1.get_local_id(2)]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load_float(tmp).load(item_ct1, d_A, vals); + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + //1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_out, qvals); + } +} + + + +//===========================k quantize blockwise================================ + +template +//__launch_bounds__(TH, 4) +SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code, + float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, + const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, + const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax) +{ + + + const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + using group_store_uc = dpct::group::workgroup_store<(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, dpct::group::store_algorithm::BLOCK_STORE_DIRECT, unsigned char, unsigned char *, sycl::nd_item<3>>; + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *d_rand = dacc_rand.get_multi_ptr().get(); + auto *d_out = dacc_out.get_multi_ptr().get(); + + //code //absmax + + if(DATA_TYPE == General8bit) + for(int i = item_ct1.get_local_id(2); i < 256; i+=item_ct1.get_local_range(2)) + smem_code[i] = dacc_code[i]; + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_A, vals); + + // 1. compute local max + // 2. broadcast local max + // 3. normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = sycl::fmax(local_abs_max, sycl::fabs((float)vals[j])); + + local_abs_max = sycl::reduce_over_group(item_ct1.get_group(), local_abs_max, sycl::maximum<>()); + + if(item_ct1.get_local_id(2) == 0) + smem_absmax_value[0] = local_abs_max; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + if(item_ct1.get_local_id(2) == 0) + dacc_absmax[i/BLOCK_SIZE] = local_abs_max; + else + local_abs_max = smem_absmax_value[0]; + + sycl::group_barrier(item_ct1.get_sub_group()); + + local_abs_max = 1.0f/local_abs_max; + + if(STOCHASTIC) + { + local_rand_idx = ((item_ct1.get_group(2)*NUM_BLOCK) + (item_ct1.get_local_id(2)*NUM) + rand_offset) % (1024-4); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_float(tmp).load(item_ct1, d_rand, rand_vals); + + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_out, qvals); + + } +} + +//===========================k dequantize blockwise================================ + + +template +SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax ) +{ + + const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (item_ct1.get_group(2) * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store 0) ? 2 : 1), dpct::group::store_algorithm::BLOCK_STORE_DIRECT, T, T *, sycl::nd_item<3>>; + + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *d_out = dacc_out.template get_multi_ptr().get(); + //A //out //code //absmax + + //typedef cub::BlockLoad LoadChar; + //typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + + for (unsigned int i = base_idx; i < n_load; i += item_ct1.get_group_range(2)*TILE_SIZE) + { + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + } + + local_abs_max = dacc_absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)];//sycl::ext::oneapi::experimental::cuda::ldg(&absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)]); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load_uc(tmp).load(item_ct1, d_A, qvals); + + + + + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + + vals[j] = dacc_code[qvals[j]]*local_abs_max;//sycl::ext::oneapi::experimental::cuda::ldg(&code[qvals[j]]*local_abs_max); + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + + item_ct1.barrier(); + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + group_store(tmp).store(item_ct1, d_out, vals); + + } +} +//=========================k dequantize====================== + +SYCL_EXTERNAL void kDequantize(float *code, unsigned char *buff_A, float *buff_out, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code) +{ + const unsigned int numThreads = item_ct1.get_local_range(2) * item_ct1.get_group_range(2); + const int idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + + + if(item_ct1.get_local_id(2) < 256) + { + smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int i = idx;i < n; i += numThreads) + { + buff_out[i] = smem_code[buff_A[i]]; + } +} + + + +//===================32 bit optimizer======================== + +/* +DPCT1110:1: The total declared local variable size in device function kPreconditionOptimizer32bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +template +SYCL_EXTERNAL +void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g, const sycl_dacc_float &dacc_unorm) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - dpct::pow(beta1, step)); + const float correction2 = 1.0f/(1.0f - dpct::pow(beta2, step)); + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + + + + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_float(tmp).load(item_ct1, d_state1, s1_vals); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_float(tmp).load(item_ct1, d_state2, s2_vals); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sycl::sqrt(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>()); + + if(item_ct1.get_local_id(2) == 0) + dpct::atomic_fetch_add(&dacc_unorm[0], s1_vals[0]); + + sycl::group_barrier(item_ct1.get_sub_group()); + } +} + +#define NUM_PER_THREAD 4 + +template +SYCL_EXTERNAL +void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - dpct::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f - dpct::pow(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt(dacc_unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_float = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g , g_vals); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_float(tmp).load(item_ct1, d_state1, s1_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + group_load_float(tmp).load(item_ct1, d_state2, s2_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + group_load(tmp).load(item_ct1, d_p, p_vals); + + + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sycl::sqrt(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store(tmp).store(item_ct1, d_p , p_vals); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_float(tmp).store(item_ct1, d_state1, s1_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_float(tmp).store(item_ct1, d_state2, s2_vals); + + + } +} + +template +SYCL_EXTERNAL +void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, + const sycl_dacc_float &dacc_unorm) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + + + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + item_ct1.barrier(); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_float(tmp).load(item_ct1, d_state1, s1_vals); + + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = (float)g_vals[j] / (sycl::sqrt(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = (float)g_vals[j] / (sycl::sqrt(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + //s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>()); + + if(item_ct1.get_local_id(2) == 0) + dpct::atomic_fetch_add(&dacc_unorm[0], s1_vals[0]); + + sycl::group_barrier(item_ct1.get_sub_group()); + } +} + + +template +SYCL_EXTERNAL +void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const + sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt(dacc_unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_float = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + + item_ct1.barrier(); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + + item_ct1.barrier(); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_float(tmp).load(item_ct1, d_state1, s1_vals); + + + + item_ct1.barrier(); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load(tmp).load(item_ct1, d_p, p_vals); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*(float)g_vals[j] / (sycl::sqrt((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*(float)g_vals[j] / (sycl::sqrt((float)s1_vals[j])+eps); + break; + } + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store(tmp).store(item_ct1, d_p, p_vals); + + + item_ct1.barrier(); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_float(tmp).store(item_ct1, d_state1, s1_vals); + + + } +} + +//===================8 bit optimizer======================== + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template + +SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, + float* smem_quantiles1, float* smem_quantiles2, + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2) +{ + const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = n - (item_ct1.get_group(2)*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (item_ct1.get_group(2)*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + + if(item_ct1.get_local_id(2) < 256) + { + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[item_ct1.get_local_id(2)] = dacc_quantiles2[item_ct1.get_local_id(2)]; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*item_ct1.get_group_range(2)*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state1, m_c1); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state2, r_c2); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*dacc_max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = sycl::fmax(local_max_s1, sycl::fabs(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*dacc_max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = sycl::fmax(local_max_s2, sycl::fabs(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = 1.0f / (1.0f - dpct::pow(beta1, step)); + float correction2 = 1.0f / (1.0f - dpct::pow(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sycl::sqrt(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); + + item_ct1.barrier(sycl::access::fence_space::local_space); + + local_max_s2 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s2, sycl::maximum<>()); + if(unorm != NULL) + { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>()); + } + + if(item_ct1.get_local_id(2) == 0) + { + atomicMax(&dacc_new_max1[0], local_max_s1); + atomicMax(&dacc_new_max2[0], local_max_s2); + if(unorm != NULL){ dpct::atomic_fetch_add(&dacc_unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + + +template +SYCL_EXTERNAL void + +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, float* smem_quantiles2, + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, const sycl_dacc_float &dacc_unorm, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2 + ) +{ + + const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - dpct::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f - dpct::pow(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/dacc_new_max1[0]; + float new_max_val2 = 1.0f/dacc_new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(dacc_unorm[0])) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + + + if(item_ct1.get_local_id(2) < 512) + { + if(item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + else + smem_quantiles2[item_ct1.get_local_id(2)-256] = dacc_quantiles2[item_ct1.get_local_id(2)-256]; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state1, c1s); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state2, c2s); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load(tmp).load(item_ct1, d_p, p_vals); + + if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*dacc_max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(sycl::signbit(smem_quantiles1[c1s[j]]) != sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*dacc_max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sycl::sqrt(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store(tmp).store(item_ct1, d_p, p_vals); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_state1, c1s); + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_state2, c2s); + + item_ct1.barrier(sycl::access::fence_space::local_space); + } +} + +template +SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, + float* smem_quantiles1, + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1) +{ + const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = n - (item_ct1.get_group(2)*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (item_ct1.get_group(2)*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + if(item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state1, m_c1); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*dacc_max1[0]; + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = sycl::fmax(local_max_s1, sycl::fabs(s1_vals[j])); + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); + if(item_ct1.get_local_id(2) == 0){ atomicMax(&dacc_new_max1[0], local_max_s1); } + if(unorm != NULL) + { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>()); + if(item_ct1.get_local_id(2) == 0){ dpct::atomic_fetch_add(&dacc_unorm[0], local_unorm); } + } + +} + +template +SYCL_EXTERNAL void + +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1,float *smem_quantiles1, const sycl_la tacc, + const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1) +{ + + const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/dacc_new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(dacc_unorm[0])) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + + + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + + + if(item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state1, c1s); + + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load(tmp).load(item_ct1, d_p, p_vals); + + if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*dacc_max1[0]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*g_val / (sycl::sqrt(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(sycl::signbit(smem_quantiles1[c1s[j]]) != sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store(tmp).store(item_ct1, d_p, p_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_state1, c1s); + + item_ct1.barrier(sycl::access::fence_space::local_space); + } +} + + +//===============================k percentile clipping============================================ + + +template +SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_g, float *dacc_gnorm_vec) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + using group_load = dpct::group::workgroup_load>; + auto *d_g = dacc_g.template get_multi_ptr().get(); + + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (item_ct1.get_group(2) * BLOCK_SIZE); i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_sum = 0.0f; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, vals); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + + local_sum = sycl::reduce_over_group(item_ct1.get_group(), local_sum, sycl::plus<>()); + + if(item_ct1.get_local_id(2) == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + dpct::atomic_fetch_add(&dacc_gnorm_vec[j], local_sum); + } + else + dpct::atomic_fetch_add(&dacc_gnorm_vec[step % 100], local_sum); + } + + } +} + + +//=========================8 bit blockwise==================================== + + +#define LANES 2 +#define QUAD 3 +template + +SYCL_EXTERNAL +void +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + float *smem_exchange1, float *smem_exchange2, + const sycl_la &tacc, const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_absmax1, const sycl_dacc_float &dacc_absmax2) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE; + const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + // 2-5% + const float correction1 = 1.0f - dpct::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f -dpct::pow(beta2, step)); + const float step_size = (-lr*correction2) / correction1; + const int lane_id = item_ct1.get_local_id(2) % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + + //quantiles1 //quantiles2 //absmax1 //absmax2 + + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[0][item_ct1.get_local_id(2)] = dacc_quantiles2[item_ct1.get_local_id(2)]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][item_ct1.get_local_id(2)] = smem_quantiles1[0][item_ct1.get_local_id(2)]; + smem_quantiles2[j][item_ct1.get_local_id(2)] = smem_quantiles2[0][item_ct1.get_local_id(2)]; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state1, c1s); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state2, c2s); + + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!sycl::isnan((float)g_vals[j]) && !sycl::isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*dacc_absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*dacc_absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } + + new_local_abs_max1 = sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); + new_local_abs_max2 = sycl::fmax(new_local_abs_max2, sycl::fabs(s2_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = sycl::reduce_over_group(item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); + new_local_abs_max2 = sycl::reduce_over_group(item_ct1.get_group(), new_local_abs_max2, sycl::maximum<>()); + + if(item_ct1.get_local_id(2) == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + if(item_ct1.get_local_id(2) == 0) + { + dacc_absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + dacc_absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load(tmp).load(item_ct1, d_p, p_vals); + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!sycl::isnan((float)g_vals[j]) && !sycl::isinf((float)g_vals[j])) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(s1_vals[j] / (sycl::sqrt(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store(tmp).store(item_ct1, d_p, p_vals); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, s1_vals[j] / new_local_abs_max1); + c2s[j] = quantize_2D<0>(quadrants2, s2_vals[j] / new_local_abs_max2); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_state1, c1s); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_state2, c2s); + + item_ct1.barrier(sycl::access::fence_space::local_space); + + } +} + + +#define LANES 2 +#define QUAD 3 +template +SYCL_EXTERNAL +void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + float *smem_exchange1, + const sycl_la &tacc, + const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_absmax1 + ) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE; + const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = item_ct1.get_local_id(2) % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][item_ct1.get_local_id(2)] = smem_quantiles1[0][item_ct1.get_local_id(2)]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + /* + DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_uc(tmp).load(item_ct1, d_state1, c1s); + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load(tmp).load(item_ct1, d_p, p_vals); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*dacc_absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = sycl::reduce_over_group(item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); + + if(item_ct1.get_local_id(2) == 0) + smem_exchange1[0] = new_local_abs_max1; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + if(item_ct1.get_local_id(2) == 0) + dacc_absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store(tmp).store(item_ct1, d_p, p_vals); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, s1_vals[j] / new_local_abs_max1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_uc(tmp).store(item_ct1, d_state1, c1s); + + } +} + +//==========================k get row col stats========================================== + +template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((item_ct1.get_group(2)*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (item_ct1.get_group(2)*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + //rowStats //colStats // nnz + + using group_load = dpct::group::workgroup_load>; + using group_exchange = dpct::group::exchange; + + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + + + sycl::half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[item_ct1.get_local_id(2) + (j*THREADS)] = -FLT_MAX; + // smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll TILE_ROWS + for (int j = 0; j < TILE_ROWS; j++) { + smem_row_nnz_values[j] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load(tmp).load(item_ct1, d_A, local_data); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = sycl::fabs(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = sycl::fmax(local_col_absmax_values[j], sycl::vec(local_data[j]).convert()[0]); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + row_absmax = (float)sycl::reduce_over_group(item_ct1.get_group(), local_data_fp32[0], sycl::maximum<>()); + if(SPARSE_DECOMP) + { + + item_ct1.barrier(sycl::access::fence_space::local_space); + local_row_nnz_count = sycl::reduce_over_group(item_ct1.get_group(), local_row_nnz_count, sycl::plus<>()); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(item_ct1.get_local_id(2) == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + } + + // 4. store data via atomicMax + // to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arrangement: [0, 8, 16, 24, ..] for t0 + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_exchange(tmp).blocked_to_striped(item_ct1, local_col_absmax_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+item_ct1.get_local_id(2)+(j*THREADS) < cols) + { + float val = dacc_colStats[base_col+(item_ct1.get_local_id(2)+(j*THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax(&dacc_colStats[base_col+(item_ct1.get_local_id(2)+(j*THREADS))], local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_row+item_ct1.get_local_id(2)+(j*THREADS) < rows) + { + float val = dacc_rowStats[base_row+(item_ct1.get_local_id(2)+(j*THREADS))]; + if(val < smem_row_absmax_values[item_ct1.get_local_id(2)+(j*THREADS)]) + atomicMax(&dacc_rowStats[base_row+(item_ct1.get_local_id(2)+(j*THREADS))], smem_row_absmax_values[item_ct1.get_local_id(2)+(j*THREADS)]); + } + + if(SPARSE_DECOMP) + if(item_ct1.get_local_id(2) < TILE_ROWS) + dacc_nnz_count_row[item_ct1.get_group(2)*TILE_ROWS+item_ct1.get_local_id(2)+1] = smem_row_nnz_values[item_ct1.get_local_id(2)]; + +} + + +//========================================k dequant mm int32fp16=================== + + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias ) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. + // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + int col = ((item_ct1.get_local_id(2) % 32) + ((item_ct1.get_group(2)/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + int base_row = (item_ct1.get_group(2)*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + sycl::half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + + using group_load_int = dpct::group::workgroup_load>; + using group_exchange = dpct::group::exchange; + + auto *d_A = dacc_A.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + //dacc_colStats //dacc_bias //dacc_rowStats + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + float colStat = col >= numCols ? 0.0f : dacc_colStats[col]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : sycl::vec(dacc_bias[col]).convert()[0]; + // no block loads for rows for now -- keep it simple + for(int j = item_ct1.get_local_id(2); j < SUBTILE_ROWS; j+=item_ct1.get_local_range(2)) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = dacc_rowStats[row]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + // each block processes SUBTILE_ROWS*32 elements + const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (item_ct1.get_local_id(2) / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (item_ct1.get_group(2)/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + + //LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load_int(tmp).load(item_ct1, d_A, local_values); + + //ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_exchange(tmp).blocked_to_striped(item_ct1, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = sycl::vec((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue).convert()[0]; + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + dacc_out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + } +} +//=====================================k double row col quant============================ + +template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((item_ct1.get_group(2)*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (item_ct1.get_group(2)*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + //colStats ,rowStats rowidx, colidx,val ,nnz + + using group_load_half = dpct::group::workgroup_load>; + using group_store_char = dpct::group::workgroup_store>; + + auto *d_A = dacc_A.get_multi_ptr().get(); + auto *d_out_col_normed = dacc_out_col_normed.get_multi_ptr().get(); + auto *d_out_row_normed = dacc_out_row_normed.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + + sycl::half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD) + j < cols) + /* + To-do: __fdividef call is used in a macro/template definition and may not be valid for all macro/template uses. + */ + local_col_stats[j] = 127.0f / dacc_colStats[base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD)+j]; + + for(int i = item_ct1.get_local_id(2); i < TILE_ROWS; i+=item_ct1.get_local_range(2)) + { + if(base_row + i < rows) + smem_row_stats[i] = dacc_rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = dacc_nnz_block_ptr[(TILE_ROWS*item_ct1.get_group(2)) + i]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_load_half(tmp).load(item_ct1, d_A, local_data); + + float row_stat = 127.0f / smem_row_stats[row]; + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(sycl::fabs((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = dpct::atomic_fetch_compare_inc(&smem_nnz_row_idx[row], UINT_MAX); + + dacc_rowidx[old_idx] = base_row+row; + dacc_colidx[old_idx] = base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD)+j; + dacc_val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(sycl::rint(sycl::vec(local_data[j]).convert()[0]*row_stat)); + } + } + else + local_quantized_data[j] = (char)(sycl::rint(sycl::vec(local_data[j]).convert()[0]*row_stat)); + } + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_char(tmp).store(item_ct1, d_out_row_normed, local_quantized_data); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(sycl::rint(sycl::vec(local_data[j]).convert()[0]*local_col_stats[j])); + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + + //StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_store_char(tmp).store(item_ct1, d_out_col_normed, local_quantized_data); + + } +} + + + +//============================================k transform row format===================================================== + +template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, +const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((item_ct1.get_group(2)*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (item_ct1.get_group(2)*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + + char local_data[ITEMS_PER_THREAD]; + + // we load row after row from the base_position + // Load data row by row + int warps = item_ct1.get_local_range(2)/32; + int warp_id = item_ct1.get_local_id(2)/32; + int warp_lane = item_ct1.get_local_id(2) % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = dacc_A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + item_ct1.barrier(sycl::access::fence_space::local_space); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consecutive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + dacc_out[offset + (base_col + jrow + subrow_loop_row)*32 + item_ct1.get_local_id(2)] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + dacc_out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + dacc_out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happens every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + dacc_out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + dacc_out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + dacc_out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + + +//========================================k extract outliers====================== + +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx) +{ + int local_colidx = dacc_idx[item_ct1.get_group(2)]; + + if(FORMAT== COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] + + // each thread reads 1 element = 1 row + for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = dacc_A[offset]; + + int out_idx = (row*idx_size) + item_ct1.get_group(2); + dacc_out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = dacc_A[offset]; + int out_idx = (row*idx_size) + item_ct1.get_group(2); + dacc_out[out_idx] = val; + } + } +} + +//=======================kfunc====================== + +template SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, const sycl::nd_item<3> &item_ct1) +{ + for(long i = (item_ct1.get_local_range(2)*item_ct1.get_group(2)) + item_ct1.get_local_id(2); i < n; i+=(item_ct1.get_local_range(2)*item_ct1.get_group_range(2))) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + +//=============================GEMMS=============================================================== + + +//============================================k spmm sparse coo=============================================== +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template + +SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + + const int count = dacc_max_count[item_ct1.get_group(2)]; + const int local_max_idx = dacc_max_idx[item_ct1.get_group(2)]; + const int offset = local_max_idx == 0 ? 0 : dacc_offset_rowidx[local_max_idx-1]; + const int local_row_idx = dacc_rowidx[offset]; + + const int warp_id = item_ct1.get_local_id(2) / 32; + const int warp_idx = item_ct1.get_local_id(2) % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + sycl::half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + sycl::half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + sycl::half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? dacc_values[offset+j] : sycl::vec(0.0f).convert()[0]; + local_colidxA[j] = j < count ? dacc_colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = item_ct1.get_local_id(2); i < SMEM_SIZE; i+=item_ct1.get_local_range(2)) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dacc_dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = dacc_B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + dacc_out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += item_ct1.get_local_range(2)*SPMM_ITEMS; + local_idx_col_B_offset += item_ct1.get_local_range(2)*SPMM_ITEMS; + } +} + + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggregate files of C into shared memory block C +//// 8. sum (7) +//// 9. write outputs to matmul output matrix +//} + +template inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +//=======================================gemm_device=================== + +#define WARPS 3 + + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B, const sycl::accessor &dacc_A, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out) +{ + +#if DPCT_COMPATIBILITY_TEMP >= 750 + + int col_offset = item_ct1.get_group(2) *32; + const int warp_id = item_ct1.get_local_id(2) / 32; + const int half_warp_id = item_ct1.get_local_id(2) / 16; + const int half_warp_lane = item_ct1.get_local_id(2) % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = item_ct1.get_local_range(2)-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + auto d_A = dacc_A.template get_multi_ptr(); + auto d_B = dacc_B.template get_multi_ptr(); + auto sg_size = item_ct1.get_sub_group(); + + + sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), c_frag, 0.0f); + + //wmma::fragment a_frag; + //wmma::fragment b_frag; + //wmma::fragment c_frag; + //wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + item_ct1.get_local_id(2); + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+(1*val_per_iter)]; + local_A[2] = dacc_A[idx+(2*val_per_iter)]; + local_A[3] = dacc_A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = dacc_B[(col_offset+col)*ldb+idx]; + local_B[col+32] = dacc_B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = dacc_B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = dacc_B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = item_ct1.get_local_range(2)-32; base_idx < K; base_idx+=item_ct1.get_local_range(2)-32) + { + idx = base_idx + item_ct1.get_local_id(2); + + item_ct1.barrier(sycl::access::fence_space::local_space); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+(1*val_per_iter)]; + local_A[2] = dacc_A[idx+(2*val_per_iter)]; + local_A[3] = dacc_A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = dacc_B[(col_offset+col)*ldb+idx]; + local_B[col+32] = dacc_B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = dacc_B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = dacc_B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.template get_multi_ptr(); + + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + + + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); + + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); + } + + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = item_ct1.get_local_id(2) % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.template get_multi_ptr(); + + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); + + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + + //wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sg_size, c_frag, d_A, (size_t)32, sycl::ext::oneapi::experimental::matrix::layout::row_major); + + if(col_offset + warp_lane < M) + dacc_out[col_offset + warp_lane] = dacc_A[warp_lane]; +#endif +} + +//===============================print=========================================== + +template void printnonzero(T *A, int num_values, const char * strval, + const sycl::stream &stream_ct1) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + /* + DPCT1015:52: Output needs adjustment. + */ + stream_ct1 << "%s %i %f\n"; +} + +template void printnonzero(float *A, int num_values, const char*strval, + const sycl::stream &stream_ct1); +template void printnonzero(sycl::half *A, int num_values, const char*strval, + const sycl::stream &stream_ct1); + +//=======================================4 bit gemm=============================== + +const float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; + + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, T *smem_C, +const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out) +{ + +#if DPCT_COMPATIBILITY_TEMP >= 750 + + int col_offset = item_ct1.get_group(2) *32; + const int warp_id = item_ct1.get_local_id(2) / 32; + const int warp_idx = item_ct1.get_local_id(2) % 32; + const int half_warp_id = item_ct1.get_local_id(2) / 16; + const int half_warp_lane = item_ct1.get_local_id(2) % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + auto d_A = dacc_A.template get_multi_ptr(); + auto d_B = dacc_B.get_multi_ptr(); + auto sg_size = item_ct1.get_sub_group(); + + + sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sg_size, c_frag, 0.0f); + + + + //wmma::fragment a_frag; + //wmma::fragment b_frag; + //wmma::fragment c_frag; + //wmma::fill_fragment(c_frag, 0.0f); + + for(int i = item_ct1.get_local_id(2); i < (8*32); i+=item_ct1.get_local_range(2)) + smem_C[i] = 0.0f; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + int ticktock = 0; + int idx = 0 + item_ct1.get_local_id(2); + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+item_ct1.get_local_range(2)-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = dacc_B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = item_ct1.get_local_range(2)-32; base_idx < K; base_idx+=item_ct1.get_local_range(2)-32) + { + idx = base_idx + item_ct1.get_local_id(2); + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+item_ct1.get_local_range(2)-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = dacc_B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = dacc_B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + + sycl::half local_absmax = absmax[absidx]; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.get_multi_ptr(); + + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + + + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); + + + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); + + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = item_ct1.get_local_id(2) % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.get_multi_ptr(); + + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); + + + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); + + } + + // 129 mu + if(warp_id == (WARPS-1)) + + //wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); + sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sg_size, c_frag, d_A, 32, sycl::ext::oneapi::experimental::matrix::layout::row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + // use smem_A itself + dacc_out[col_offset + warp_lane] = dacc_A[warp_lane]; +#endif +} + + + + +//=========================================4 bit gemm naive=============== + + +#define num_values_4bit 32 + +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype) +{ + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + // datatype absmax + const int warp_idx = item_ct1.get_local_id(2) / 32; + const int warp_lane = item_ct1.get_local_id(2) % 32; + const int row_B = (THREADS/32)*item_ct1.get_group(2) + warp_idx; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + + T local_absmax = T(0.0f); + + for(int i = item_ct1.get_local_id(2); i < 16; i++) + quant_map[i] = T(dacc_datatype[i]); + item_ct1.barrier(sycl::access::fence_space::local_space); + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int inner_idx_halved = inner_idx/2; + int offset_B = ldb*row_B; + int absidx = ((2*offset_B)+inner_idx)/blocksize; + + local_absmax = dacc_absmax[absidx]; + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = dacc_B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if DPCT_COMPATIBILITY_TEMP >= 800 + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = dacc_A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if DPCT_COMPATIBILITY_TEMP >= 800 + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = sycl::reduce_over_group(item_ct1.get_sub_group(), local_C, sycl::plus<>()); + + if(row_B < M && warp_lane == 0) + dacc_out[row_B] = T(local_C); + +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, long n, + const sycl::nd_item<3> &item_ct1); +template SYCL_EXTERNAL void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n, + const sycl::nd_item<3> &item_ct1); +template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, long n, + const sycl::nd_item<3> &item_ct1); +template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, long n, + const sycl::nd_item<3> &item_ct1); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out);; +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + unsigned char *smem_B, + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + unsigned char *smem_B, + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + unsigned char *smem_B, + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + unsigned char *smem_B, + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *quant_map, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out, + const sycl_dacc_float &dacc_absmax, + const sycl_dacc_float &dacc_datatype); + +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::ext::oneapi::bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::ext::oneapi::bfloat16 *quant_map, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out, + const sycl_dacc_float &dacc_absmax, + const sycl_dacc_float &dacc_datatype); + +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + float *quant_map, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out, + const sycl_dacc_float &dacc_absmax, + const sycl_dacc_float &dacc_datatype); + + +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + + + +//==================supported template decls======================================================= + +template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias); + + +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); + +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); + + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + + +template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); + +template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); + +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); + +template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template SYCL_EXTERNAL void kEstimateQuantiles(float* __restrict__ const A, float *code, const float offset, const float max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); +template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half* __restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, sycl::half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, sycl::half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) +MAKE_PreconditionOptimizer32bit1State(ADAM, sycl::half) +MAKE_PreconditionOptimizer32bit1State(ADAM, float) +MAKE_PreconditionOptimizer32bit1State(ADAM, sycl::ext::oneapi::bfloat16) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template SYCL_EXTERNAL void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm); \ + +MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(RMSPROP, sycl::half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, sycl::half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, sycl::half) +MAKE_Optimizer32bit1State(ADAGRAD, float) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g, const sycl_dacc_float &dacc_unorm); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) +MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfloat16) + + +template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); + +template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); + +template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, sycl::half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, sycl::half) +MAKE_PreconditionStatic8bit1State(LION, float) + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,const sycl_la &tacc, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ + const sycl_dacc_uc &dacc_state1, \ + const sycl_dacc_float &dacc_quantiles1, \ + const sycl_dacc_float &dacc_absmax1); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::ext::oneapi::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAM, sycl::ext::oneapi::bfloat16, 2048, 8) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2,const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, \ + const sycl_dacc_float &dacc_new_max2); \ + +MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, const sycl_dacc_float &dacc_unorm, \ + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, \ + const sycl_dacc_float &dacc_new_max2); \ + +MAKE_optimizerStatic8bit2State(ADAM, sycl::half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template SYCL_EXTERNAL void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, + float *dacc_gnorm_vec); +template SYCL_EXTERNAL void kPercentileClipping(sycl::half * __restrict__ g, float *gnorm_vec, int step, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, float *dacc_gnorm_vec); + + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4) + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2,const sycl_la &tacc, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, \ + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, \ + const sycl_dacc_float &dacc_absmax1, const sycl_dacc_float &dacc_absmax2); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::ext::oneapi::bfloat16, 2048, 8) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h new file mode 100644 index 000000000..33fe46354 --- /dev/null +++ b/csrc/sycl/kernels.h @@ -0,0 +1,265 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include "ops.h" + +#ifndef kernels +#define kernels + +#pragma once + +//================typedefs=================================== + +typedef sycl::local_accessor sycl_la; +typedef sycl::accessor sycl_dacc; +typedef sycl::accessor sycl_dacc_float; +typedef sycl::accessor sycl_dacc_uc; +typedef sycl::accessor sycl_dacc_char; + +//=========================================================== +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); + +extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, + const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, + const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code); + +extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code); + +template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, + const sycl::nd_item<3> &item_ct1, + float *smem_code, + float *smem_absmax_value, + const sycl_la &tacc,const sycl::accessor &dacc_A, + const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, + const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +//=========================k-dequant blockwise ====================== +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +//====================32 bit headers============================= + +template +extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g, const sycl_dacc_float &dacc_unorm); + +template +extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); + +template +extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, + const sycl_dacc_float &dacc_unorm); + +template +extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const + sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm); + + +//==============================8 bit headers========================== +template +extern SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, + float *smem_quantiles1, + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); + + +template +extern SYCL_EXTERNAL void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + const sycl_la &tacc, + const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); + + + +template +extern SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, + float *smem_quantiles1, float *smem_quantiles2, + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2); + + +template +extern SYCL_EXTERNAL void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + float *smem_quantiles2,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, const sycl_dacc_float &dacc_unorm, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2); + +//====================8 bit blockwise========================= + +template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + float *smem_exchange1, float *smem_exchange2, + const sycl_la &tacc, const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_absmax1, const sycl_dacc_float &dacc_absmax2); + +template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + float *smem_exchange1, + const sycl_la &tacc, + const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_absmax1); + +//=======================percentile clipping============================ +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, float *dacc_gnorm_vec); + + +//===============histogram======================== + +extern SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_dacc_float &dacc_histogram, const sycl_dacc &dacc_index1, + const sycl_dacc &dacc_index2, const sycl_dacc_float &dacc_src); + +//====================spm======================= +template +extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, + int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, T *B, sycl::half *out, + float *__restrict__ const dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +//=====================mm dequant ==================================== +template +extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, + float *__restrict__ const colStats, sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, + const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, + const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, + const sycl::accessor &dacc_bias +); +//==================k row col stats===================== + +template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, + const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); + +//===========================double row col quant=================== +template extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + char *out_col_normed, char *out_row_normed, int *rowidx, + int *colidx, sycl::half *val, + int *__restrict__ nnz_block_ptr, float threshold, + int rows, int cols, int tiledCols, + const sycl::nd_item<3> &item_ct1, + float *smem_row_stats, unsigned int *smem_nnz_row_idx, + const sycl_la &tacc, const sycl::accessor &dacc_A, + const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, + const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); + +//==============================k transfrom row col===================== +template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, +const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +//========================k extract outliers========================= +template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); + +//=========================gemm device============================ +template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B, const sycl::accessor &dacc_A, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out); + +//=========================gemm 4 bit inf================================ +template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, + T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); + +//====================gemm 4 bit naive inf============================ +template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); + +template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, + const sycl::nd_item<3> &item_ct1); + +#endif diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp new file mode 100644 index 000000000..d345224d0 --- /dev/null +++ b/csrc/sycl/ops.cpp @@ -0,0 +1,2047 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include "ops.h" +#include "kernels.h" +#include +#include +#include +#include +#include + +#include +#include + + +#define ERR_NOT_IMPLEMENTED 100 + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 +#define NUM_PER_THREAD 4 + +using namespace dnnl; + +typedef sycl::ext::oneapi::bfloat16 bf16; + +using namespace BinSearch; +using std::cout; +using std::endl; + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +//================================histogram 2d============================================== + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + int size = NUM_BLOCK; + + + sycl::buffer buff_histogram(histogram,sycl::range<1>(size)); + sycl::buffer buff_index1(index1,sycl::range<1>(size)); + sycl::buffer buff_index2(index2,sycl::range<1>(size)); + sycl::buffer buff_src(src,sycl::range<1>(size)); + + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_histogram(buff_histogram, cgh, sycl::read_write); + sycl::accessor dacc_index1(buff_index1, cgh, sycl::read_write); + sycl::accessor dacc_index2(buff_index2, cgh, sycl::read_write); + sycl::accessor dacc_src(buff_src, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kHistogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n, item_ct1, dacc_histogram, dacc_index1, dacc_index2, dacc_src); + }); + }); + } + +} +//============================estimate quantiles=============================== +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + std::memset(code, 0, 256*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait()); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_code(code,sycl::range<1>(size)); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + //using group_radix_sort = dpct::group::radix_sort; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(sycl::range<1>(temp_storage_size), cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + + + auto std_numeric_limits_T_max_ct3 = std::numeric_limits::max(); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kEstimateQuantiles(A, code, offset, std_numeric_limits_T_max_ct3, n, item_ct1, tacc, dacc_A, dacc_code); + + }); + }); + } + +} + +//============================k quantize =============================== +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_code(code,sycl::range<1>(size)); + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kQuantize(code, A, out, n, item_ct1, smem_code_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out, dacc_code); + }); + }); + } + +} + +//============================k dequantize=============================== +void dequantize(float *code, unsigned char *A, float *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + unsigned char *buff_A; + float *buff_out; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), NUM_BLOCK); + q_ct1.memcpy((void*)(buff_A), (void*)(A), NUM_BLOCK); + + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kDequantize(code, buff_A, buff_out, n, item_ct1, smem_code_acc_ct1.get_pointer()); + }); + }); + + + } + //back memcpy + q_ct1.memcpy((void *)(out), (void*)(buff_out), NUM_BLOCK); + q_ct1.memcpy((void*)(A), (void*)(buff_A), NUM_BLOCK); + +} + +//============================quantize blockwise=============================== + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + sycl::context ctx = q_ct1.get_context(); + int size= NUM_BLOCK; + for(int i=0; i< NUM_BLOCK; i++){ out[i]=out[(DATA_TYPE > 0) ? i/2 : i];}; + + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_rand(rand,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_code(code,sycl::range<1>(size)); + sycl::buffer buff_absmax(absmax,sycl::range<1>(size)); + + + + if(blocksize == 4096) + + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + + //__shared__ vars for funtions + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + else if(blocksize == 2048) + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + else if(blocksize == 1024) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + //__shared__vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + else if(blocksize == 512) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + else if(blocksize == 256) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + else if(blocksize == 128) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + else if(blocksize == 64) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); + }); + }); + } + +} + + +//============================k dequantize blockwise=============================== +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_code(code,sycl::range<1>(size)); + sycl::buffer buff_absmax(absmax,sycl::range<1>(size)); + + + + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh){ + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + if(DATA_TYPE > 0){ + kDequantizeBlockwise(code, A, absmax, out, blocksize/2, n, item_ct1, tacc, dacc_A, dacc_out, dacc_code, dacc_absmax); } + else{ + kDequantizeBlockwise(code, A, absmax, out, blocksize, n, item_ct1, tacc, dacc_A, dacc_out, dacc_code, dacc_absmax); + } + }); + + }); + +} + + +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(cudaPeekAtLastError()); +//} + + + +//============================32 bit optimizer=============================== +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) + try { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + int size= NUM_BLOCK; + + sycl::buffer buff_g(g,sycl::range<1>(size)); + sycl::buffer buff_p(p,sycl::range<1>(size)); + sycl::buffer buff_state1(state1,sycl::range<1>(size)); + sycl::buffer buff_state2(state2,sycl::range<1>(size)); + sycl::buffer buff_unorm(unorm, sycl::range<1>(size)); + + switch(OPTIMIZER) + { + case ADAM: + if(max_unorm > 0.0f) + { + std::memset(unorm, 0, 1*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizer32bit2State(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_state1, dacc_state2, dacc_g, dacc_unorm); + }); + }); + } + + } + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizer32bit2State(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_state2, dacc_unorm); + }); + }); + } + + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + std::memset(unorm, 0, 1*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1, dacc_unorm); + }); + }); + } + } + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm); + }); + }); + } + + break; + case LION: + // in lion, the momentum update after the parameter update + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm); + }); + }); + } + + + if(max_unorm > 0.0f) + { + std::memset(unorm, 0, 1*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1, dacc_unorm); + }); + }); + } + + } + break; + } + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + + +//============================8 bit optimizer=============================== + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) + try { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + sycl::buffer buff_g(g,sycl::range<1>(size)); + sycl::buffer buff_p(p,sycl::range<1>(size)); + sycl::buffer buff_state1(state1,sycl::range<1>(size)); + sycl::buffer buff_state2(state2,sycl::range<1>(size)); + + sycl::buffer buff_quantiles1(quantiles1,sycl::range<1>(size)); + sycl::buffer buff_quantiles2(quantiles2,sycl::range<1>(size)); + sycl::buffer buff_max1(max1,sycl::range<1>(size)); + sycl::buffer buff_max2(max2,sycl::range<1>(size)); + sycl::buffer buff_new_max1(new_max1,sycl::range<1>(size)); + sycl::buffer buff_new_max2(new_max2,sycl::range<1>(size)); + sycl::buffer buff_unorm(unorm,sycl::range<1>(size)); + + + if(max_unorm > 0.0f){ + std::memset(unorm, 0, 1*sizeof(float)); } + + switch(OPTIMIZER) + { + case ADAM: + std::memset(new_max1, 0, 1*sizeof(float)); + std::memset(new_max2, 0, 1*sizeof(float)); + + //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + //DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait()); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles2(buff_quantiles2, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_max2(buff_max2, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max2(buff_new_max2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizerStatic8bit2State(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_state2, dacc_unorm, dacc_quantiles1, dacc_quantiles2, dacc_max1, dacc_max2, dacc_new_max1 , dacc_new_max2); + }); + }); + } + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles2(buff_quantiles2, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_max2(buff_max2, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max2(buff_new_max2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit2State(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2, dacc_unorm, dacc_quantiles1, dacc_quantiles2, dacc_max1, dacc_max2, dacc_new_max1 , dacc_new_max2); + }); + }); + } + + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + + std::memset(new_max1, 0, 1*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); + }); + }); + } + + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1,smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); + }); + }); + } + + break; + case LION: + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); + }); + }); + } + std::memset(new_max1, 0, 1*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); + }); + }); + } + + break; + default: + break; + } + +}catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + +//============================8 bit blockwise optimizer=============================== + +#define BLOCKSIZE_2STATE 2048 +#define NUM_2STATE 8 +#define BLOCKSIZE_1STATE 2048 +#define NUM_1STATE 8 + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) + try { + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = 0; + int size = BLOCKSIZE_2STATE; + + sycl::buffer buff_g(g,sycl::range<1>(size)); + sycl::buffer buff_p(p,sycl::range<1>(size)); + sycl::buffer buff_state1(state1,sycl::range<1>(size)); + sycl::buffer buff_state2(state2,sycl::range<1>(size)); + sycl::buffer buff_quantiles1(quantiles1,sycl::range<1>(size)); + sycl::buffer buff_quantiles2(quantiles2,sycl::range<1>(size)); + sycl::buffer buff_absmax1(absmax1,sycl::range<1>(size)); + sycl::buffer buff_absmax2(absmax2,sycl::range<1>(size)); + + + switch(OPTIMIZER) + { + case ADAM: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles2(buff_quantiles2, cgh, sycl::read_write); + sycl::accessor dacc_absmax1(buff_absmax1, cgh, sycl::read_write); + sycl::accessor dacc_absmax2(buff_absmax2, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE), sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit2StateBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n,item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1,smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2, dacc_quantiles1, dacc_quantiles2, dacc_absmax1, dacc_absmax2); + }); + }); + } + + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_absmax1(buff_absmax1, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit1StateBlockwise(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_quantiles1, dacc_absmax1); + }); + }); + } + + break; + } + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +//============================percentile clipping=============================== + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + int size = NUM_BLOCK; + + sycl::buffer buff_g(g,sycl::range<1>(size)); + std::memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)); + sycl::buffer buff_gnorm_vec(gnorm_vec, sycl::range<1>(size)); + + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_gnorm_vec(buff_gnorm_vec, cgh, sycl::read_write); + + //sycl::local_accessor dacc_gnorm_vec(sycl::range<1>(size), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kPercentileClipping(g, gnorm_vec, step, n, item_ct1, tacc, dacc_g, dacc_gnorm_vec.get_pointer()); + }); + }); + } + +} + +//==========================dequant mm int 32 fp16========================== + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *bias, int numRows, int numCols) +{ + int threads = 512; + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + int size = NUM_BLOCK; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_rowStats (rowStats, sycl::range<1>(size)); + sycl::buffer buff_colStats (colStats, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + sycl::buffer buff_bias (bias, sycl::range<1>(size)); + + + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_bias(buff_bias, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_rowStats_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), + [=](sycl::nd_item<3> item_ct1) { + kdequant_mm_int32_fp16<4, 128, 512>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n, item_ct1,smem_rowStats_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rowStats, dacc_colStats, dacc_out, dacc_bias); + }); + + }); + +} + +//========================GEMM============================ + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) + try { + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + int status; + + dpct::gemm(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, B, dpct::library_data_t::real_int8, ldb, beta, C, dpct::library_data_t::real_int32, ldc, dpct::library_data_t::real_int32); + + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) + try { + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + int status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + dpct::gemm_batch(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, (long long int)strideA, B, dpct::library_data_t::real_int8, ldb, (long long int)strideB, beta, C, dpct::library_data_t::real_int32, ldc, (long long int)strideC, batchCount, dpct::library_data_t::real_int32); + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +//=================================transform GEMM============================== + +template void transform( T *A, T *out, int dim1, int dim2) +{ + + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + void *Aout; + auto dev = sycl::device(sycl::gpu_selector_v); + auto ctx = sycl::context(dev); + int ldA = get_leading_dim(dim1, dim2); + int ldOut = get_leading_dim(dim1, dim2); + int ldAOut = get_leading_dim(dim1, dim2); + + dnnl::engine engine = sycl_interop::make_engine(dev, ctx); + // column major + const memory::dims a_strides = memory::dims {1, ldA}; + const auto a_md = DTYPE ==32 ? memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides); + const memory::dims out_strides = memory::dims {ldOut, 1}; + const auto out_md = DTYPE ==32 ? memory::desc({dim1, dim2}, dt::s32, out_strides) : memory::desc({dim1, dim2}, dt::s8, out_strides); + const memory::dims aout_strides = memory::dims {ldAOut, 1}; + const auto aout_md = DTYPE == 32 ? memory::desc({dim1, dim2}, dt::s32, aout_strides) : memory::desc({dim1, dim2}, dt::s8, aout_strides); + + //memory align + memory a_mem(a_md, engine, A); + memory out_mem(out_md, engine, out); + memory aout_mem(aout_md, engine, Aout); + + //create dnnl stream + auto q_ct1 = sycl::queue(ctx, dev); + dnnl::stream stream = sycl_interop::make_stream(engine, q_ct1); + + primitive_attr attr; + + auto matmul_pd = matmul::primitive_desc(engine, a_md, out_md, aout_md, attr); + auto matmul_prim = matmul(matmul_pd); + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, a_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, out_mem}); + matmul_args.insert({DNNL_ARG_DST, aout_mem}); + + matmul_prim.execute(stream, matmul_args); + stream.wait(); + +} + + +template void transform(int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int32_t *A, int32_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int32_t *A, int32_t *out, int dim1, int dim2); + + +//========================igemmlt============================================ + +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + try { + + using tag = memory::format_tag; + using dt = memory::data_type; + auto dev = sycl::device(sycl::gpu_selector_v); + auto ctx = sycl::context(dev); + + dnnl::engine engine = sycl_interop::make_engine(dev, ctx); + // column major + const memory::dims a_strides = memory::dims {1, lda}; + const auto a_md = memory::desc({m, k}, dt::s8, a_strides); + const memory::dims b_strides = memory::dims {ldb, 1}; + const auto b_md = memory::desc({k, n}, dt::s8, b_strides); + const memory::dims c_strides = memory::dims {ldc, 1}; + const auto c_md = DTYPE_OUT == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides); + + //memory align + memory a_mem(a_md, engine); + memory b_mem(b_md, engine); + memory c_mem(c_md, engine); + memory scales_C_mem({{1}, dt::f32, {1}}, engine, row_scale); + + //create dnnl stream + auto q_ct1 = sycl::queue(ctx, dev); + dnnl::stream stream = sycl_interop::make_stream(engine, q_ct1); + + primitive_attr attr; + if (SCALE_ROWS) { + attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 1 << 1); + } + + auto matmul_pd = matmul::primitive_desc(engine, a_md, b_md, c_md, attr); + auto matmul_prim = matmul(matmul_pd); + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, a_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, b_mem}); + matmul_args.insert({DNNL_ARG_DST, c_mem}); + + if (SCALE_ROWS) { + matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_C_mem}); + } + matmul_prim.execute(stream, matmul_args); + stream.wait(); + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + + +//===========================gemm_host============================================ + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_B (B, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 160), sycl::range<3>(1, 1, 160)), + [=](sycl::nd_item<3> item_ct1) { + gemm_device(m, n, k, A, B, out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), + dacc_A, dacc_B, dacc_out); + }); + }); + } + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + + +} + + +//============================gemm 4bit inference ================================ + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_B (B, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_C_acc_ct1(sycl::range<1>(8*32), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 96), sycl::range<3>(1, 1, 96)), + [=](sycl::nd_item<3> item_ct1) { + kgemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer(), dacc_A, dacc_B, dacc_out); + }); + }); + } + +} + + +//============================gemm 4 bit inference naive ================= + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_B (B, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + sycl::buffer buff_absmax(absmax, sycl::range<1>(size)); + sycl::buffer buff_datatype(datatype, sycl::range<1>(size)); + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + sycl::accessor dacc_datatype(buff_datatype, cgh, sycl::read_write); + sycl::local_accessor quant_map_acc_ct1(sycl::range<1>(16), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + kgemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer(), dacc_A, dacc_B, dacc_out, dacc_absmax, dacc_datatype); + }); + }); + } + +} +//================================spm coo================================== + +void spmm_coo(int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) +{ + + try{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + + dpct::sparse::sparse_matrix_desc_t descA; + std::shared_ptr descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + + // Create dense matrix C + + descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + + descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); + // allocate an external buffer if needed + + bufferSize = 0; + + dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1); + + + dpct::sparse::spmm(q_ct1, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float); + // destroy matrix/vector descriptors + descA.reset(); + descB.reset(); + descC.reset(); + sycl::free(dBuffer, q_ct1); + + } + catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); + } + +} + +//===============================spm _coo _very _sparse========================= + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + sycl::buffer buff_max_count(max_count,sycl::range<1>(size)); + sycl::buffer buff_max_idx(max_idx,sycl::range<1>(size)); + sycl::buffer buff_offset_rowidx(offset_rowidx,sycl::range<1>(size)); + sycl::buffer buff_rowidx(rowidx,sycl::range<1>(size)); + sycl::buffer buff_colidx(colidx,sycl::range<1>(size)); + sycl::buffer buff_values(values,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_B(B, sycl::range<1>(size)); + sycl::buffer buff_dequant_stats(dequant_stats,sycl::range<1>(size)); + + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_max_count(buff_max_count, cgh, sycl::read_write); + sycl::accessor dacc_max_idx(buff_max_idx, cgh, sycl::read_write); + sycl::accessor dacc_offset_rowidx(buff_offset_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_colidx(buff_colidx, cgh, sycl::read_write); + sycl::accessor dacc_rowidx(buff_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_values(buff_values, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_dequant_stats(buff_dequant_stats, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + + + //smem + sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kspmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB, item_ct1, smem_dequant_stats_acc_ct1.get_pointer(), dacc_max_count, dacc_max_idx, dacc_offset_rowidx, dacc_rowidx, dacc_colidx, dacc_values, dacc_B, dacc_out, dacc_dequant_stats); + }); + }); + } + +} + +//======================================non gemm 2d quants============================================ + +//===========================Row col stats================================= + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + + int tile_cols = STATS_THREADS*STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + int size = NUM_BLOCK; + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_nnz_count_row(nnz_count_row,sycl::range<1>(size)); + sycl::buffer buff_rowStats(rowStats,sycl::range<1>(size)); + sycl::buffer buff_colStats(colStats,sycl::range<1>(size)); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_nnz_count_row(buff_nnz_count_row, cgh, sycl::read_write); + //__shared__ vars + sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_row_nnz_values_acc_ct1(sycl::range<1>(256), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + if(nnz_threshold == 0.0){ + kgetColRowStats(A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, + smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), tacc, + dacc_A, dacc_rowStats, dacc_colStats, dacc_nnz_count_row); + } + else if(nnz_threshold != 0.0){ + kgetColRowStats(A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, + smem_row_absmax_values_acc_ct1.get_pointer(),smem_row_nnz_values_acc_ct1.get_pointer(), + tacc, dacc_A, dacc_rowStats, dacc_colStats, dacc_nnz_count_row); + } + }); + }); + +} + + +//===================================double row col quant====================== + +void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = 0; + int size = NUM_BLOCK; + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out_col_normed(out_col_normed,sycl::range<1>(size)); + sycl::buffer buff_out_row_normed(out_row_normed,sycl::range<1>(size)); + + sycl::buffer buff_rowStats(rowStats,sycl::range<1>(size)); + sycl::buffer buff_colStats(colStats,sycl::range<1>(size)); + sycl::buffer buff_rowidx(rowidx,sycl::range<1>(size)); + sycl::buffer buff_colidx(colidx,sycl::range<1>(size)); + sycl::buffer buff_val(val,sycl::range<1>(size)); + sycl::buffer buff_nnz_block_ptr(nnz_block_ptr,sycl::range<1>(size)); + + + + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + num_blocks = row_tiles * col_tiles; + + + if(threshold > 0.0f) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out_col_normed(buff_out_col_normed, cgh, sycl::read_write); + sycl::accessor dacc_out_row_normed(buff_out_row_normed, cgh, sycl::read_write); + + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_rowidx(buff_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_colidx(buff_colidx, cgh, sycl::read_write); + sycl::accessor dacc_val(buff_val, cgh, sycl::read_write); + sycl::accessor dacc_nnz_block_ptr(buff_nnz_block_ptr, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed, dacc_rowStats, dacc_colStats, dacc_rowidx, dacc_colidx, dacc_val, dacc_nnz_block_ptr); + }); + }); + } + else + { + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out_col_normed(buff_out_col_normed, cgh, sycl::read_write); + sycl::accessor dacc_out_row_normed(buff_out_row_normed, cgh, sycl::read_write); + + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_rowidx(buff_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_colidx(buff_colidx, cgh, sycl::read_write); + sycl::accessor dacc_val(buff_val, cgh, sycl::read_write); + sycl::accessor dacc_nnz_block_ptr(buff_nnz_block_ptr, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed, dacc_rowStats, dacc_colStats, dacc_rowidx, dacc_colidx, dacc_val, dacc_nnz_block_ptr); + }); + }); + + } + +} +//========================== transform row to format================================ +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = 0; + int size = NUM_BLOCK; + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + + + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + num_blocks = row_tiles * col_tiles; + + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + + + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_data_acc_ct1(sycl::range<1>(32*33*8), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), + [=](sycl::nd_item<3> item_ct1) { + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>(A, out, rows, cols, tiledCols, outRows, outCols, item_ct1, smem_data_acc_ct1.get_pointer(), dacc_A, dacc_out); + }); + }); + +} + +//===========================extract outliers=========================== + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) +{ + int threads = 512; + // we load 128 column values per warp + int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); + int tiledRows = 0; + int size = NUM_BLOCK; + int num_blocks = idx_size; + + if(FORMAT == COL_TURING) + { + tiledRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + tiledRows = fill_up_to_nearest_multiple(rows, 32); + } + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_idx(idx,sycl::range<1>(size)); + + + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_idx(buff_idx, cgh, sycl::read_write); + + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), + [=](sycl::nd_item<3> item_ct1) { + kExtractOutliers(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1, dacc_A, dacc_out, dacc_idx); + }); + }); + +} + +//==================================func=========================== + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kfunc(A, B, value, n, item_ct1); + }); + }); + +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, bf16 * A, unsigned char* B, float *absmax, float *datatype, bf16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, sycl::half * A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, int bits); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +template void estimateQuantiles(sycl::half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, bf16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, bf16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, bf16 *out, int blocksize, const int n); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, sycl::half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, bf16) +MAKE_optimizer32bit(MOMENTUM, sycl::half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(RMSPROP, sycl::half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, sycl::half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, bf16) +MAKE_optimizer32bit(ADAGRAD, sycl::half) +MAKE_optimizer32bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, sycl::half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, sycl::half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, sycl::half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, sycl::half) +MAKE_optimizerStatic8bit(LION, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(sycl::half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(sycl::half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(sycl::half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(sycl::half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(bf16, LION); +MAKE_optimizerStatic8bitBlockwise(sycl::half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(sycl::half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(bf16, ADAM); diff --git a/csrc/sycl/ops.h b/csrc/sycl/ops.h new file mode 100644 index 000000000..19887f028 --- /dev/null +++ b/csrc/sycl/ops.h @@ -0,0 +1,187 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + + +#define THREADS_PER_BLOCKS (512) + + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + dpct::queue_ptr m_handle; + + Context() + { + dpct::queue_ptr handle; + handle = &dpct::get_default_queue(); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + dpct::queue_ptr m_handle; + + ContextLt() + { + dpct::queue_ptr handle; + handle = &dpct::get_default_queue(); + m_handle = handle; + } + +}; + +class ContextCusparse +{ + public: + sycl::queue *m_handle; + + ContextCusparse() + { + sycl::queue *handle; + handle = &dpct::get_default_queue(); + m_handle = handle; + } + +}; + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(dpct::queue_ptr ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(dpct::queue_ptr ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, + sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *bias, int numRows, + int numCols); +void getColRowStats(sycl::half *A, float *rowStats, float *colStats, + int *nnz_count_row, float nnz_threshold, int rows, + int cols); +void doubleRowColQuant(sycl::half *A, float *rowStats, float *colStats, + char *out_col_normed, char *out_row_normed, int *rowidx, + int *colidx, sycl::half *val, int *nnz_block_ptr, + float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(sycl::queue *handle, int *A_rowidx, int *A_colidx, + sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, sycl::half *B, int ldc, sycl::half *C, + bool transposed_B); + +template +void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, + int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, T *B, sycl::half *out, + float *dequant_stats, int nnz_rows, int nnz, + int rowsA, int rowsB, int colsB); + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +void matmul4bite(sycl::half *A, unsigned char *B, sycl::half *out, int lda, + int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/sycl/utilities.h b/csrc/sycl/utilities.h new file mode 100644 index 000000000..0e8bcd0d6 --- /dev/null +++ b/csrc/sycl/utilities.h @@ -0,0 +1,138 @@ + +#include +#include +#include +#include +#include +#include "ops.h" +#include +#include +#include +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 +#define NUM_PER_THREAD 4 + +typedef sycl::ext::oneapi::bfloat16 bf16; + +using std::cout; +using std::endl; + + + +namespace dpct{ +namespace group{ +enum store_algorithm { + + BLOCK_STORE_DIRECT, + BLOCK_STORE_STRIPED, + // To-do: BLOCK_STORE_WARP_TRANSPOSE + // To-do: BLOCK_STORE_VECTORIZE + +}; + +/// Stores a blocked arrangement of work items linear segment of items. +template +__dpct_inline__ void store_blocked(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // This implementation does not take in account range storage across + // workgroup items To-do: Decide whether range storage is required for group + // storage + size_t linear_tid = item.get_local_linear_id(); + OutputIteratorT workitem_itr = block_itr + (linear_tid * ITEMS_PER_WORK_ITEM); +#pragma unroll + for (uint32_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) { + workitem_itr[idx] = items[idx]; + } +} + +/// Stores a striped arrangement of work items linear segment of items. +template +__dpct_inline__ void store_striped(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // This implementation does not take in account range storage across + // workgroup items To-do: Decide whether range storage is required for group + // storage + size_t linear_tid = item.get_local_linear_id(); + OutputIteratorT workitem_itr = block_itr + linear_tid; + size_t GROUP_WORK_ITEMS = item.get_global_range().size(); +#pragma unroll + for (uint32_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) { + workitem_itr[(idx * GROUP_WORK_ITEMS)] = items[idx]; + } +} + +/// Stores a warp-striped arrangement of work items linear segment of items. +// Created as free function until exchange mechanism is +// implemented. +// To-do: inline this function with BLOCK_STORE_WARP_TRANSPOSE mechanism +template +__dpct_inline__ void +store_subgroup_striped(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // This implementation does not take in account range loading across + // workgroup items To-do: Decide whether range loading is required for group + // loading + // This implementation uses unintialized memory for loading linear segments + // into warp striped arrangement. + uint32_t subgroup_offset = item.get_sub_group().get_local_linear_id(); + uint32_t subgroup_size = item.get_sub_group().get_local_linear_range(); + uint32_t subgroup_idx = item.get_sub_group().get_group_linear_id(); + uint32_t initial_offset = + (subgroup_idx * ITEMS_PER_WORK_ITEM * subgroup_size) + subgroup_offset; + OutputIteratorT workitem_itr = block_itr + initial_offset; +#pragma unroll + for (uint32_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) { + workitem_itr[(idx * subgroup_size)] = items[idx]; + } +} + +// template parameters : +// ITEMS_PER_WORK_ITEM: size_t variable controlling the number of items per +// thread/work_item +// ALGORITHM: store_algorithm variable controlling the type of store operation. +// InputT: type for input sequence. +// OutputIteratorT: output iterator type +// Item : typename parameter resembling sycl::nd_item<3> . +template +class workgroup_store { +public: + static size_t get_local_memory_size(size_t group_work_items) { return 0; } + workgroup_store(uint8_t *local_memory) : _local_memory(local_memory) {} + + __dpct_inline__ void store(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + if constexpr (ALGORITHM == BLOCK_STORE_DIRECT) { + store_blocked(item, block_itr, (&items)[ITEMS_PER_WORK_ITEM]); + } else if constexpr (ALGORITHM == BLOCK_STORE_STRIPED) { + store_striped(item, block_itr, (&items)[ITEMS_PER_WORK_ITEM]); + } + } + +private: + uint8_t *_local_memory; + +}; +} +} + + +