diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 42fcd4ad0d536..e7ce2cc8b2b83 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -291,6 +291,7 @@ else()
         ${MLAS_SRC_DIR}/dgemm.cpp
         ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp
       )
+      set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE")
       check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10)
       if(HAS_POWER10)
         set(CMAKE_REQUIRED_FLAGS "-mcpu=power10")
@@ -318,8 +319,10 @@ else()
       endif()
       set(mlas_platform_srcs_power10
         ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp
+        ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp
       )
-      set_source_files_properties(${mlas_platform_srcs_power10} PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
+      set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE")
+      set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
       set(mlas_platform_srcs
         ${mlas_platform_srcs}
         ${mlas_platform_srcs_power10}
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 82b1f5c978b18..e5ec4486184f6 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -499,6 +499,7 @@ extern "C" {
     MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel;
     MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelPOWER10;
     MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernel;
+    MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10;
 #else
     MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero;
     MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd;
@@ -1886,7 +1887,7 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector)
 #if defined(MLAS_SSE2_INTRINSICS)
     _mm_store_pd(Buffer, Vector);
 #elif defined(MLAS_VSX_INTRINSICS)
-    vec_st(Vector, 0, Buffer);
+    *((MLAS_FLOAT64X2*)Buffer) = Vector;
 #endif
 }
 
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 5c92ce915fe4e..de7fee8c07aa9 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -379,6 +379,7 @@ Return Value:
     bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1));
     if (HasP10Instructions) {
         this->GemmFloatKernel = MlasSgemmKernelPOWER10;
+        this->GemmDoubleKernel = MlasDgemmKernelPOWER10;
     }
 #endif
 #endif
diff --git a/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp b/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp
new file mode 100644
index 0000000000000..11638bc33f256
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp
@@ -0,0 +1,418 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    DgemmKernelPOWER10.cpp
+
+Abstract:
+
+    This module implements the kernels for the double precision matrix/matrix
+    multiply operation (DGEMM).
+ +--*/ + +#include "DgemmKernelpower.h" +struct MlasDgemmBroadcastAElementsMMA +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + double ARow[RowCount], + const double* A, + size_t lda + ) + { + ARow[Row] = A [Row * lda]; + } +}; + +template +MLAS_FORCEINLINE +void +MlasDgemmComputeAElements( + MLAS_FLOAT64X2 AElements[RowCount], + MLAS_FLOAT64X2 ABroadcast[RowCount] + ) +{ + ABroadcast[0] = vec_mergee (AElements[0], AElements[1]); + ABroadcast[1] = vec_mergee (AElements[2], AElements[3]); + ABroadcast[2] = vec_mergeo (AElements[0], AElements[1]); + ABroadcast[3] = vec_mergeo (AElements[2], AElements[3]); +} + +template +MLAS_FORCEINLINE +void +MlasDgemmComputeBlockMMA( + __vector_quad acc[8], + MLAS_FLOAT64X2 ABroadcast[RowCount], + MLAS_FLOAT64X2 A2Broadcast[RowCount], + const double* B, + size_t CountM + ) +{ + MLAS_FLOAT64X2 BElements[4]; + typedef __vector unsigned char vec_t; + __vector_pair A2pair, Apair; +#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 2)) + __builtin_mma_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]); + if (CountM == 8) { + __builtin_mma_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]); + } +#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2)) + Apair = *((__vector_pair *)((void *)&ABroadcast[0])); + if (CountM == 8) { + A2pair = *((__vector_pair *)((void *)&A2Broadcast[0])); + } +#else + __builtin_vsx_assemble_pair (&Apair, (vec_t)ABroadcast[1], (vec_t)ABroadcast[0]); + if (CountM == 8) { + __builtin_vsx_assemble_pair (&A2pair, (vec_t)A2Broadcast[1], (vec_t)A2Broadcast[0]); + } +#endif + BElements[0] = MlasLoadFloat64x2(B); + BElements[1] = MlasLoadFloat64x2(B + 2); + BElements[2] = MlasLoadFloat64x2(B + 4); + BElements[3] = MlasLoadFloat64x2(B + 6); + __builtin_mma_xvf64gerpp (&acc[0], Apair, (vec_t)BElements[0]); + __builtin_mma_xvf64gerpp (&acc[1], Apair, (vec_t)BElements[1]); + __builtin_mma_xvf64gerpp (&acc[2], Apair, (vec_t)BElements[2]); + __builtin_mma_xvf64gerpp (&acc[3], Apair, (vec_t)BElements[3]); + if (CountM == 8) { + __builtin_mma_xvf64gerpp (&acc[4], A2pair, (vec_t)BElements[0]); + __builtin_mma_xvf64gerpp (&acc[5], A2pair, (vec_t)BElements[1]); + __builtin_mma_xvf64gerpp (&acc[6], A2pair, (vec_t)BElements[2]); + __builtin_mma_xvf64gerpp (&acc[7], A2pair, (vec_t)BElements[3]); + } +} +template +struct MlasDgemmStoreVectorMMA +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT64X2 Result[4], + double* C, + size_t ldc, + MLAS_FLOAT64X2 AlphaBroadcast, + bool ZeroMode + ) + { + MLAS_FLOAT64X2 *rowC; + if (ZeroMode) { + rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount]; + rowC[0] = Result[Row] * AlphaBroadcast; + } else { + rowC = (MLAS_FLOAT64X2 *) &C[Row * ldc + VectorCount]; + rowC[0] += Result[Row] * AlphaBroadcast; + } + } +}; + +struct MlasDgemmMultiplyAlphaTrailingMMA +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT64X2 Accumulators[RowCount], + MLAS_FLOAT64X2 AlphaBroadcast + ) + { + Accumulators[Row] = MlasMultiplyFloat64x2(Accumulators[Row], AlphaBroadcast); + } +}; +template +struct MlasDgemmStoreScalarMMA +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT64X2 Accumulators[RowCount], + double* C, + size_t ldc, + bool ZeroMode + ) + { + double* c = C + Row * ldc + Lane; + double Value = Accumulators[Row][Lane]; + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + +template +MLAS_FORCEINLINE +size_t +MlasDgemmMMAProcessCount( + const double* A, + const double* B, + 
double* C, + size_t CountM, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT64X2 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const double* a = A; + size_t k = CountK; + + MLAS_FLOAT64X2 Accumulators[2][RowCount] = {{ 0 }}; + MLAS_FLOAT64X2 Result[RowCount]; + MLAS_FLOAT64X2 AElements[RowCount]; + MLAS_FLOAT64X2 ABroadcast[RowCount] = { 0 }; + MLAS_FLOAT64X2 A2Broadcast[RowCount] = { 0 }; + MLAS_FLOAT64X2 A3Broadcast[RowCount] = { 0 }; + MLAS_FLOAT64X2 A4Broadcast[RowCount] = { 0 }; + double ARow[RowCount] = { 0 }; + double A2Row[RowCount] = { 0 }; + __vector_quad acc[8]; + + // + // Clear the block accumulators. + // + __builtin_mma_xxsetaccz(&acc[0]); + __builtin_mma_xxsetaccz(&acc[1]); + __builtin_mma_xxsetaccz(&acc[2]); + __builtin_mma_xxsetaccz(&acc[3]); + __builtin_mma_xxsetaccz(&acc[4]); + __builtin_mma_xxsetaccz(&acc[5]); + __builtin_mma_xxsetaccz(&acc[6]); + __builtin_mma_xxsetaccz(&acc[7]); + + // + // Compute the output block. + // + while (k >= 4) { + + MlasLoopUnroll()(AElements, a, lda); + MlasDgemmComputeAElements(AElements, ABroadcast); + MlasLoopUnroll()(AElements, a+2, lda); + MlasDgemmComputeAElements(AElements, A3Broadcast); + if (CountM == 8) { + MlasLoopUnroll()(AElements, a + ( lda * 4), lda); + MlasDgemmComputeAElements(AElements, A2Broadcast); + MlasLoopUnroll()(AElements, (a+2) + ( lda * 4), lda); + MlasDgemmComputeAElements(AElements, A4Broadcast); + } + MlasDgemmComputeBlockMMA(&acc[0], &ABroadcast[0], &A2Broadcast[0], B, CountM); + MlasDgemmComputeBlockMMA(&acc[0], &ABroadcast[2], &A2Broadcast[2], B+8, CountM); + MlasDgemmComputeBlockMMA(&acc[0], &A3Broadcast[0], &A4Broadcast[0], B+16, CountM); + MlasDgemmComputeBlockMMA(&acc[0], &A3Broadcast[2], &A4Broadcast[2], B+24, CountM); + B += 8 * 4; + a += 4; + k -= 4; + } + while (k > 0) { + MlasLoopUnroll()(ARow, a, lda); + if (CountM == 8) { + MlasLoopUnroll()(A2Row, a + (lda * 4), lda); + } + + MlasDgemmComputeBlockMMA(&acc[0], (MLAS_FLOAT64X2 *)ARow, (MLAS_FLOAT64X2 *)A2Row, B, CountM); + a += 1; + B += 8; + k -= 1; + } + if (CountN >= 8) { + + // + // Store the entire output block. + // + __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[2]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[3]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[6]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[7]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + } + } else { + + // + // Store the partial output block. 
+ // + + if (CountN >= 6) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[2]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[6]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 6 > 0) { + __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[7]); + } + } + if (CountN - 6 > 0) { + __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[3]); + } + } else if (CountN >= 4) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[1]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + __builtin_mma_disassemble_acc ((void *)Result, &acc[5]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 4 > 0) { + __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[6]); + } + } + if (CountN - 4 > 0) { + __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[2]); + } + } else if (CountN >= 2) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[0]); + MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + __builtin_mma_disassemble_acc ((void *)Result, &acc[4]); + MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 2 > 0) { + __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[5]); + } + } + if (CountN - 2 > 0) { + __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[1]); + } + } else { + __builtin_mma_disassemble_acc ((void *)Accumulators[0], &acc[0]); + if (CountM == 8) { + __builtin_mma_disassemble_acc ((void *)Accumulators[1], &acc[4]); + } + } + + // + // Store the remaining unaligned columns. + // + C += (CountN & ~1); + CountN &= 1; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators[0], AlphaBroadcast); + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll()(Accumulators[1], AlphaBroadcast); + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + } + + break; + } + + C += 8; + CountN -= 8; + + } while (CountN > 0); + + return CountM; +} + +size_t +MLASCALL +MlasDgemmKernelPOWER10( + const double* A, + const double* B, + double* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + double alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasDgemmCopyPackB or MlasDgemmTransposePackB. + + C - Supplies the address of matrix C. 
+ + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see DGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + size_t RowsHandled; + MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha); + if (CountM >= 8) { + RowsHandled = MlasDgemmMMAProcessCount<4>(A, B, C, 8 ,CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 4) { + RowsHandled = MlasDgemmMMAProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h b/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h index a7f780a22d01f..0dca7e4e43961 100644 --- a/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h +++ b/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h @@ -6,293 +6,16 @@ Licensed under the MIT License. Module Name: - DgemmKernelPower.cpp + DgemmKernelpower.h Abstract: - This module implements the kernels for the single precision matrix/matrix + This module implements the kernels for the double precision matrix/matrix multiply operation (DGEMM). --*/ -#include "mlasi.h" - -// -// Templates to ensure that a loop is unrolled. -// - -template -struct MlasLoopUnrollStep -{ - template - MLAS_FORCEINLINE - static - void - Step( - IterationArgs&&... Arguments - ) - { - IterationType::template Iteration(Arguments...); - MlasLoopUnrollStep::template Step(Arguments...); - } -}; - -template -struct MlasLoopUnrollStep -{ - template - MLAS_FORCEINLINE - static - void - Step( - IterationArgs&&... - ) - { - // Terminate the loop. - } -}; - -template -struct MlasLoopUnroll -{ - template - MLAS_FORCEINLINE - void - operator()( - IterationArgs&&... Arguments - ) - { - MlasLoopUnrollStep::template Step(Arguments...); - } -}; - -// -// Templates used with loop unrolling to perform an action on one row of the -// output. 
-// - -struct MlasDgemmZeroAccumulators -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount][4] - ) - { - Accumulators[Row][0] = MlasZeroFloat64x2(); - Accumulators[Row][1] = MlasZeroFloat64x2(); - Accumulators[Row][2] = MlasZeroFloat64x2(); - Accumulators[Row][3] = MlasZeroFloat64x2(); - } -}; - -struct MlasDgemmLoadAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 AElements[RowCount], - const double* A, - size_t lda - ) - { - AElements[Row] = MlasLoadFloat64x2(A + Row * lda); - } -}; - -struct MlasDgemmBroadcastAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 ABroadcast[RowCount], - const double* A, - size_t lda - ) - { - ABroadcast[Row] = MlasBroadcastFloat64x2(A + Row * lda); - } -}; - -template -struct MlasDgemmSplatAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 AElements[RowCount], - MLAS_FLOAT64X2 ABroadcast[RowCount] - ) - { - ABroadcast[Row] = vec_splat(AElements[Row], Lane); - } -}; - -struct MlasDgemmMultiplyAddRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount][4], - MLAS_FLOAT64X2 ABroadcast[RowCount], - MLAS_FLOAT64X2 BElements[4] - ) - { - Accumulators[Row][0] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[0], Accumulators[Row][0]); - Accumulators[Row][1] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[1], Accumulators[Row][1]); - Accumulators[Row][2] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[2], Accumulators[Row][2]); - Accumulators[Row][3] = MlasMultiplyAddFloat64x2(ABroadcast[Row], BElements[3], Accumulators[Row][3]); - } -}; - -template -MLAS_FORCEINLINE -void -MlasDgemmComputeBlock( - MLAS_FLOAT64X2 Accumulators[RowCount][4], - MLAS_FLOAT64X2 ABroadcast[RowCount], - const double* B - ) -{ - MLAS_FLOAT64X2 BElements[4]; - - BElements[0] = MlasLoadFloat64x2(B); - BElements[1] = MlasLoadFloat64x2(B + 2); - BElements[2] = MlasLoadFloat64x2(B + 4); - BElements[3] = MlasLoadFloat64x2(B + 6); - - MlasLoopUnroll()(Accumulators, ABroadcast, BElements); -} - -struct MlasDgemmMultiplyAlphaRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[4], - MLAS_FLOAT64X2 AlphaBroadcast - ) - { - Accumulators[Index] = MlasMultiplyFloat64x2(Accumulators[Index], AlphaBroadcast); - } -}; - -struct MlasDgemmMultiplyAlphaAddRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[4], - MLAS_FLOAT64X2 AlphaBroadcast, - const double* C - ) - { - Accumulators[Index] = MlasMultiplyAddFloat64x2(Accumulators[Index], - AlphaBroadcast, MlasLoadFloat64x2(C + Index * 2)); - } -}; - -struct MlasDgemmStoreRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[4], - double* C - ) - { - MlasStoreFloat64x2(C + Index * 2, Accumulators[Index]); - } -}; - -template -struct MlasDgemmStoreVector -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount][4], - double* C, - size_t ldc, - MLAS_FLOAT64X2 AlphaBroadcast, - bool ZeroMode - ) - { - double* c = C + Row * ldc; - if (ZeroMode) { - MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); - } else { - MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); - } - MlasLoopUnroll()(Accumulators[Row], c); - - // - // Shift down any unaligned elements to the bottom for further processing. 
- // - - if (VectorCount < 4) { - Accumulators[Row][0] = Accumulators[Row][VectorCount]; - } - } -}; - -struct MlasDgemmMultiplyAlphaTrailing -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount][4], - MLAS_FLOAT64X2 AlphaBroadcast - ) - { - Accumulators[Row][0] = MlasMultiplyFloat64x2(Accumulators[Row][0], AlphaBroadcast); - } -}; - -template -struct MlasDgemmStoreScalar -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount][4], - double* C, - size_t ldc, - bool ZeroMode - ) - { - double* c = C + Row * ldc + Lane; - double Value = MlasExtractLaneFloat64x2(Accumulators[Row][0]); - - if (!ZeroMode) { - Value += *c; - } - - *c = Value; - } -}; +#include "FgemmKernelpower.h" template MLAS_FORCEINLINE @@ -322,20 +45,20 @@ MlasDgemmProcessCount( // Clear the block accumulators. // - MlasLoopUnroll()(Accumulators); + MlasLoopUnroll()(Accumulators); // // Compute the output block. // while (k >= 2) { - MlasLoopUnroll()(AElements, a, lda); + MlasLoopUnroll()(AElements, a, lda); - MlasLoopUnroll>()(AElements, ABroadcast); - MlasDgemmComputeBlock(Accumulators, ABroadcast, B); + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); - MlasLoopUnroll>()(AElements, ABroadcast); - MlasDgemmComputeBlock(Accumulators, ABroadcast, B + 8); + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 8); a += 2; B += 8 * 2; @@ -343,8 +66,8 @@ MlasDgemmProcessCount( } if (k > 0) { - MlasLoopUnroll()(ABroadcast, a, lda); - MlasDgemmComputeBlock(Accumulators, ABroadcast, B); + MlasLoopUnroll()(ABroadcast, a, lda); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); a += 1; B += 8; @@ -357,7 +80,7 @@ MlasDgemmProcessCount( // Store the entire output block. // - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } else { @@ -367,11 +90,11 @@ MlasDgemmProcessCount( // if (CountN >= 6) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } else if (CountN >= 4) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } else if (CountN >= 2) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } // // Store the remaining unaligned columns. @@ -381,9 +104,9 @@ MlasDgemmProcessCount( if (CountN > 0) { - MlasLoopUnroll()(Accumulators, AlphaBroadcast); + MlasLoopUnroll()(Accumulators, AlphaBroadcast); - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); } break; diff --git a/onnxruntime/core/mlas/lib/power/FgemmKernelpower.h b/onnxruntime/core/mlas/lib/power/FgemmKernelpower.h new file mode 100644 index 0000000000000..3746dbc82b3f6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/power/FgemmKernelpower.h @@ -0,0 +1,333 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelPower.h + +Abstract: + + This module implements the kernels for the single/double precision matrix/matrix + multiply operation (DGEMM/SGEMM). 
+ +--*/ + +#include "mlasi.h" +#if defined(SINGLE) +#define MLAS_FLOATTYPE MLAS_FLOAT32X4 +#define MLAS_GEMMTYPE float +#define MLAS_LOAD_FLOAT MlasLoadFloat32x4 +#define MLAS_ZERO_FLOAT MlasZeroFloat32x4 +#define MLAS_STORE_FLOAT MlasStoreFloat32x4 +#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat32x4 +#define MLAS_MUL_FLOAT MlasMultiplyFloat32x4 +#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat32x4 +#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat32x4 +#else +#define MLAS_FLOATTYPE MLAS_FLOAT64X2 +#define MLAS_GEMMTYPE double +#define MLAS_LOAD_FLOAT MlasLoadFloat64x2 +#define MLAS_ZERO_FLOAT MlasZeroFloat64x2 +#define MLAS_STORE_FLOAT MlasStoreFloat64x2 +#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat64x2 +#define MLAS_MUL_FLOAT MlasMultiplyFloat64x2 +#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat64x2 +#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat64x2 +#endif +// +// Templates to ensure that a loop is unrolled. +// + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... Arguments + ) + { + IterationType::template Iteration(Arguments...); + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... + ) + { + // Terminate the loop. + } +}; + +template +struct MlasLoopUnroll +{ + template + MLAS_FORCEINLINE + void + operator()( + IterationArgs&&... Arguments + ) + { + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +// +// Templates used with loop unrolling to perform an action on one row of the +// output. +// + +struct MlasFgemmZeroAccumulators +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4] + ) + { + Accumulators[Row][0] = MLAS_ZERO_FLOAT(); + Accumulators[Row][1] = MLAS_ZERO_FLOAT(); + Accumulators[Row][2] = MLAS_ZERO_FLOAT(); + Accumulators[Row][3] = MLAS_ZERO_FLOAT(); + } +}; + +struct MlasFgemmLoadAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE AElements[RowCount], + const MLAS_GEMMTYPE* A, + size_t lda + ) + { + AElements[Row] = MLAS_LOAD_FLOAT(A + Row * lda); + } +}; + +struct MlasFgemmBroadcastAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE ABroadcast[RowCount], + const MLAS_GEMMTYPE* A, + size_t lda + ) + { + ABroadcast[Row] = MLAS_BROADCAST_FLOAT(A + Row * lda); + } +}; + +template +struct MlasFgemmSplatAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE AElements[RowCount], + MLAS_FLOATTYPE ABroadcast[RowCount] + ) + { + ABroadcast[Row] = vec_splat(AElements[Row], Lane); + } +}; + +struct MlasFgemmMultiplyAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE ABroadcast[RowCount], + MLAS_FLOATTYPE BElements[4] + ) + { + Accumulators[Row][0] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[0], Accumulators[Row][0]); + Accumulators[Row][1] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[1], Accumulators[Row][1]); + Accumulators[Row][2] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[2], Accumulators[Row][2]); + Accumulators[Row][3] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[3], Accumulators[Row][3]); + } +}; + +template +MLAS_FORCEINLINE +void +MlasFgemmComputeBlock( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE ABroadcast[RowCount], + const MLAS_GEMMTYPE* B + ) +{ + MLAS_FLOATTYPE BElements[4]; +#if defined(SINGLE) + BElements[0] = 
MLAS_LOAD_FLOAT(B); + BElements[1] = MLAS_LOAD_FLOAT(B + 4); + BElements[2] = MLAS_LOAD_FLOAT(B + 8); + BElements[3] = MLAS_LOAD_FLOAT(B + 12); +#else + BElements[0] = MLAS_LOAD_FLOAT(B); + BElements[1] = MLAS_LOAD_FLOAT(B + 2); + BElements[2] = MLAS_LOAD_FLOAT(B + 4); + BElements[3] = MLAS_LOAD_FLOAT(B + 6); +#endif + + MlasLoopUnroll()(Accumulators, ABroadcast, BElements); +} + +struct MlasFgemmMultiplyAlphaRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_FLOATTYPE AlphaBroadcast + ) + { + Accumulators[Index] = MLAS_MUL_FLOAT(Accumulators[Index], AlphaBroadcast); + } +}; + +struct MlasFgemmMultiplyAlphaAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_FLOATTYPE AlphaBroadcast, + const MLAS_GEMMTYPE* C + ) + { +#if defined(SINGLE) + Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], + AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 4)); +#else + Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], + AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 2)); +#endif + } +}; + +struct MlasFgemmStoreRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_GEMMTYPE* C + ) + { +#if defined(SINGLE) + MLAS_STORE_FLOAT(C + Index * 4, Accumulators[Index]); +#else + MLAS_STORE_FLOAT(C + Index * 2, Accumulators[Index]); +#endif + } +}; + +template +struct MlasFgemmStoreVector +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_GEMMTYPE* C, + size_t ldc, + MLAS_FLOATTYPE AlphaBroadcast, + bool ZeroMode + ) + { + MLAS_GEMMTYPE* c = C + Row * ldc; + + if (ZeroMode) { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); + } else { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); + } + + MlasLoopUnroll()(Accumulators[Row], c); + + // + // Shift down any unaligned elements to the bottom for further processing. 
+ // + + if (VectorCount < 4) { + Accumulators[Row][0] = Accumulators[Row][VectorCount]; + } + } +}; + +struct MlasFgemmMultiplyAlphaTrailing +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE AlphaBroadcast + ) + { + Accumulators[Row][0] = MLAS_MUL_FLOAT(Accumulators[Row][0], AlphaBroadcast); + } +}; + +template +struct MlasFgemmStoreScalar +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_GEMMTYPE* C, + size_t ldc, + bool ZeroMode + ) + { + MLAS_GEMMTYPE* c = C + Row * ldc + Lane; + MLAS_GEMMTYPE Value = MLAS_EXTRACT_FLOAT(Accumulators[Row][0]); + + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp b/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp index 9ba4f8062a4a2..bc08af0cd7651 100644 --- a/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp +++ b/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp @@ -188,10 +188,10 @@ MlasSgemmMMAProcessCount( // while (k >= 4) { - MlasLoopUnroll()(AElements, a, lda); + MlasLoopUnroll()(AElements, a, lda); MlasSgemmComputeAElements(AElements, ABroadcast); if (CountM == 8) { - MlasLoopUnroll()(AElements, a + ( lda * 4), lda); + MlasLoopUnroll()(AElements, a + ( lda * 4), lda); MlasSgemmComputeAElements(AElements, A2Broadcast); } MlasSgemmComputeBlockMMA(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h b/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h index 1cd8d7dd16782..53be544bdbe3f 100644 --- a/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h +++ b/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - SgemmKernelPower.cpp + SgemmKernelpower.h Abstract: @@ -15,286 +15,7 @@ Module Name: --*/ -#include "mlasi.h" - -// -// Templates to ensure that a loop is unrolled. -// - -template -struct MlasLoopUnrollStep -{ - template - MLAS_FORCEINLINE - static - void - Step( - IterationArgs&&... Arguments - ) - { - IterationType::template Iteration(Arguments...); - MlasLoopUnrollStep::template Step(Arguments...); - } -}; - -template -struct MlasLoopUnrollStep -{ - template - MLAS_FORCEINLINE - static - void - Step( - IterationArgs&&... - ) - { - // Terminate the loop. - } -}; - -template -struct MlasLoopUnroll -{ - template - MLAS_FORCEINLINE - void - operator()( - IterationArgs&&... Arguments - ) - { - MlasLoopUnrollStep::template Step(Arguments...); - } -}; - -// -// Templates used with loop unrolling to perform an action on one row of the -// output. 
-// - -struct MlasSgemmZeroAccumulators -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount][4] - ) - { - Accumulators[Row][0] = MlasZeroFloat32x4(); - Accumulators[Row][1] = MlasZeroFloat32x4(); - Accumulators[Row][2] = MlasZeroFloat32x4(); - Accumulators[Row][3] = MlasZeroFloat32x4(); - } -}; - -struct MlasSgemmLoadAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 AElements[RowCount], - const float* A, - size_t lda - ) - { - AElements[Row] = MlasLoadFloat32x4(A + Row * lda); - } -}; - -struct MlasSgemmBroadcastAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 ABroadcast[RowCount], - const float* A, - size_t lda - ) - { - ABroadcast[Row] = MlasBroadcastFloat32x4(A + Row * lda); - } -}; - -template -struct MlasSgemmSplatAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 AElements[RowCount], - MLAS_FLOAT32X4 ABroadcast[RowCount] - ) - { - ABroadcast[Row] = vec_splat(AElements[Row], Lane); - } -}; - -struct MlasSgemmMultiplyAddRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount][4], - MLAS_FLOAT32X4 ABroadcast[RowCount], - MLAS_FLOAT32X4 BElements[4] - ) - { - Accumulators[Row][0] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[0], Accumulators[Row][0]); - Accumulators[Row][1] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[1], Accumulators[Row][1]); - Accumulators[Row][2] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[2], Accumulators[Row][2]); - Accumulators[Row][3] = MlasMultiplyAddFloat32x4(ABroadcast[Row], BElements[3], Accumulators[Row][3]); - } -}; - -template -MLAS_FORCEINLINE -void -MlasSgemmComputeBlock( - MLAS_FLOAT32X4 Accumulators[RowCount][4], - MLAS_FLOAT32X4 ABroadcast[RowCount], - const float* B - ) -{ - MLAS_FLOAT32X4 BElements[4]; - - BElements[0] = MlasLoadFloat32x4(B); - BElements[1] = MlasLoadFloat32x4(B + 4); - BElements[2] = MlasLoadFloat32x4(B + 8); - BElements[3] = MlasLoadFloat32x4(B + 12); - - MlasLoopUnroll()(Accumulators, ABroadcast, BElements); -} - -struct MlasSgemmMultiplyAlphaRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[4], - MLAS_FLOAT32X4 AlphaBroadcast - ) - { - Accumulators[Index] = MlasMultiplyFloat32x4(Accumulators[Index], AlphaBroadcast); - } -}; - -struct MlasSgemmMultiplyAlphaAddRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[4], - MLAS_FLOAT32X4 AlphaBroadcast, - const float* C - ) - { - Accumulators[Index] = MlasMultiplyAddFloat32x4(Accumulators[Index], - AlphaBroadcast, MlasLoadFloat32x4(C + Index * 4)); - } -}; - -struct MlasSgemmStoreRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[4], - float* C - ) - { - MlasStoreFloat32x4(C + Index * 4, Accumulators[Index]); - } -}; - -template -struct MlasSgemmStoreVector -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount][4], - float* C, - size_t ldc, - MLAS_FLOAT32X4 AlphaBroadcast, - bool ZeroMode - ) - { - float* c = C + Row * ldc; - - if (ZeroMode) { - MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); - } else { - MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); - } - - MlasLoopUnroll()(Accumulators[Row], c); - - // - // Shift down any unaligned elements to the bottom for further processing. 
- // - - if (VectorCount < 4) { - Accumulators[Row][0] = Accumulators[Row][VectorCount]; - } - } -}; - -struct MlasSgemmMultiplyAlphaTrailing -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount][4], - MLAS_FLOAT32X4 AlphaBroadcast - ) - { - Accumulators[Row][0] = MlasMultiplyFloat32x4(Accumulators[Row][0], AlphaBroadcast); - } -}; - -template -struct MlasSgemmStoreScalar -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount][4], - float* C, - size_t ldc, - bool ZeroMode - ) - { - float* c = C + Row * ldc + Lane; - float Value = MlasExtractLaneFloat32x4(Accumulators[Row][0]); - - if (!ZeroMode) { - Value += *c; - } - - *c = Value; - } -}; +#include "FgemmKernelpower.h" template MLAS_FORCEINLINE @@ -324,7 +45,7 @@ MlasSgemmProcessCount( // Clear the block accumulators. // - MlasLoopUnroll()(Accumulators); + MlasLoopUnroll()(Accumulators); // // Compute the output block. @@ -332,19 +53,19 @@ MlasSgemmProcessCount( while (k >= 4) { - MlasLoopUnroll()(AElements, a, lda); + MlasLoopUnroll()(AElements, a, lda); - MlasLoopUnroll>()(AElements, ABroadcast); - MlasSgemmComputeBlock(Accumulators, ABroadcast, B); + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); - MlasLoopUnroll>()(AElements, ABroadcast); - MlasSgemmComputeBlock(Accumulators, ABroadcast, B + 16); + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 16); - MlasLoopUnroll>()(AElements, ABroadcast); - MlasSgemmComputeBlock(Accumulators, ABroadcast, B + 32); + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 32); - MlasLoopUnroll>()(AElements, ABroadcast); - MlasSgemmComputeBlock(Accumulators, ABroadcast, B + 48); + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 48); a += 4; B += 16 * 4; @@ -353,8 +74,8 @@ MlasSgemmProcessCount( while (k > 0) { - MlasLoopUnroll()(ABroadcast, a, lda); - MlasSgemmComputeBlock(Accumulators, ABroadcast, B); + MlasLoopUnroll()(ABroadcast, a, lda); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); a += 1; B += 16; @@ -367,7 +88,7 @@ MlasSgemmProcessCount( // Store the entire output block. // - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } else { @@ -376,11 +97,11 @@ MlasSgemmProcessCount( // if (CountN >= 12) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } else if (CountN >= 8) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } else if (CountN >= 4) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); } // @@ -392,16 +113,16 @@ MlasSgemmProcessCount( if (CountN > 0) { - MlasLoopUnroll()(Accumulators, AlphaBroadcast); + MlasLoopUnroll()(Accumulators, AlphaBroadcast); - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); if (CountN >= 2) { - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); } if (CountN >= 3) { - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); } }
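
Note on the shared template header: both SgemmKernelpower.h and DgemmKernelpower.h now include FgemmKernelpower.h, which selects the element and vector types (MLAS_GEMMTYPE, MLAS_FLOATTYPE, and the MLAS_*_FLOAT wrappers) from the SINGLE define that the CMake change passes for the SGEMM translation units. The following standalone sketch (not MLAS source; the names DotRow/GemmType/GEMM_NAME are illustrative stand-ins) shows the same compile-twice pattern in isolation: building it with -DSINGLE produces the float variant, building it without produces the double variant.

#include <cstddef>
#include <cstdio>

#if defined(SINGLE)
typedef float GemmType;     // stand-in for MLAS_GEMMTYPE when -DSINGLE is set
#define GEMM_NAME "SGEMM"
#else
typedef double GemmType;    // stand-in for MLAS_GEMMTYPE in the default (double) build
#define GEMM_NAME "DGEMM"
#endif

// Written once against the abstracted element type, the way the MlasFgemm*
// helpers are written once and reused by both precisions.
static GemmType
DotRow(const GemmType* a, const GemmType* b, std::size_t n)
{
    GemmType sum = 0;
    for (std::size_t i = 0; i < n; i++) {
        sum += a[i] * b[i];
    }
    return sum;
}

int
main()
{
    GemmType a[4] = {1, 2, 3, 4};
    GemmType b[4] = {5, 6, 7, 8};
    std::printf("%s dot product = %f\n", GEMM_NAME, (double)DotRow(a, b, 4));
    return 0;
}

Compiling this file twice, once with -DSINGLE and once without, mirrors how the CMake change builds SgemmKernelPower.cpp and SgemmKernelPOWER10.cpp with -DSINGLE while the DGEMM sources take the default double-precision path.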
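Note on kernel selection: the platform.cpp hunk only switches GemmDoubleKernel to MlasDgemmKernelPOWER10 when hwcap2 reports both PPC_FEATURE2_MMA and PPC_FEATURE2_ARCH_3_1. The sketch below shows how that check is commonly performed on Linux/ppc64le; the includes are assumptions for illustration (the hunk does not show which headers MLAS pulls these symbols from), and this is not the MLAS initialization code itself.

#include <sys/auxv.h>        // getauxval, AT_HWCAP2 (Linux)
#include <asm/cputable.h>    // PPC_FEATURE2_MMA, PPC_FEATURE2_ARCH_3_1 (Linux/powerpc, assumed header)
#include <cstdio>

int
main()
{
    // Query the POWER feature bits the same way the dispatch logic gates the MMA kernels.
    unsigned long hwcap2 = getauxval(AT_HWCAP2);

    bool HasP10Instructions =
        (hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1);

    std::printf("POWER10 MMA GEMM kernels %s\n",
                HasP10Instructions ? "would be enabled" : "would be disabled");
    return 0;
}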