diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index ed7287cee0..426d4cda79 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,277 +39,248 @@ void bli_cntx_init_zen( cntx_t* cntx ) { - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - blksz_t thresh[ BLIS_NUM_THRESH ]; - - // Set default kernel blocksizes and functions. - bli_cntx_init_zen_ref( cntx ); - - // ------------------------------------------------------------------------- - - // Update the context with optimized native gemm micro-kernels and - // their storage preferences. - bli_cntx_set_l3_nat_ukrs - ( - 8, - - // gemm - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, - BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, - BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, - - // gemmtrsm_l - BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, - - // gemmtrsm_u - BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, - cntx - ); - + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 8, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 4, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 20, #if 1 - // Update the context with optimized packm kernels. 
- bli_cntx_set_packm_kers - ( - 8, - BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, - BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, - BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, - BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, - BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, - BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, - cntx - ); + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 4, - - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, - - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - cntx - ); - - // Update the context with optimized level-1v kernels. - bli_cntx_set_l1v_kers - ( - 10, - - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - - // axpyv + // axpyv #if 0 - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int, + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int, #else - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, -#endif + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, -#if 0 - // copyv - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, #endif - // dotv BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, - // scalv #if 0 - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int, + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int, #else - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, -#endif - -#if 0 - // setv - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - - // swapv - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, #endif - cntx - ); - - // Initialize level-3 blocksize objects with architecture-specific values. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. 
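+ // (Each bli_blksz_init_easy() call below supplies one value per datatype, in the s/d/c/z order shown by the header comment.)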
+ // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); /* - Multi Instance performance improvement of DGEMM when binded to a CCX - In Multi instance each thread runs a sequential DGEMM. + Multi Instance performance improvement of DGEMM when binded to a CCX + In Multi instance each thread runs a sequential DGEMM. - a) If BLIS is run in a multi-instance mode with - CPU freq 2.6/2.2 Ghz - DDR4 clock frequency 2400Mhz - mc = 240, kc = 512, and nc = 2040 - has better performance on EPYC server, over the default block sizes. + a) If BLIS is run in a multi-instance mode with + CPU freq 2.6/2.2 Ghz + DDR4 clock frequency 2400Mhz + mc = 240, kc = 512, and nc = 2040 + has better performance on EPYC server, over the default block sizes. - b) If BLIS is run in Single Instance mode - mc = 510, kc = 1024 and nc = 4080 + b) If BLIS is run in Single Instance mode + mc = 510, kc = 1024 and nc = 4080 */ - #if BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 510, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 1024, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); - #else - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); - #endif - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); - - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_blkszs - ( - BLIS_NAT, 7, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, - BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, - BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, - BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, - BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, - // level-1f - BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, - BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, - cntx - ); - - // ------------------------------------------------------------------------- - - // Initialize sup thresholds with architecture-appropriate values. - // s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, -1, -1 ); - bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, -1, -1 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, -1, -1 ); - - // Initialize the context with the sup thresholds. - bli_cntx_set_l3_sup_thresh - ( - 3, - BLIS_MT, &thresh[ BLIS_MT ], - BLIS_NT, &thresh[ BLIS_NT ], - BLIS_KT, &thresh[ BLIS_KT ], - cntx - ); - - // Initialize the context with the sup handlers. - bli_cntx_set_l3_sup_handlers - ( - 1, - BLIS_GEMM, bli_gemmsup_ref, - //BLIS_GEMMT, bli_gemmtsup_ref, - cntx - ); - - // Update the context with optimized small/unpacked gemm kernels. 
- bli_cntx_set_l3_sup_kers - ( - 16, - //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16m, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16n, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n, TRUE, -#if 0 - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, -#endif +#ifdef BLIS_ENABLE_ZEN_BLOCK_SIZES + #if BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES -#if 0 - // NOTE: This set of kernels is likely broken and therefore disabled. - BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, -#endif - cntx - ); - - // Initialize level-3 sup blocksize objects with architecture-specific - // values. 
- // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, -1, -1, - 9, 9, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, -1, -1 ); -#if 0 - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); -#endif + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 510, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 1024, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); - // Update the context with the current architecture's register and cache - // blocksizes for small/unpacked level-3 problems. - bli_cntx_set_l3_sup_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); + #else + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); + + #endif +#else + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 4080, 3056 ); +#endif + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM execution. + bli_cntx_set_trsm_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 110 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 1, + BLIS_GEMM, bli_gemmsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
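+ // 28 entries follow: the eight row/column storage combinations for double and for float, plus six combinations each for scomplex and dcomplex; every entry is a (storage combination, datatype, kernel, preference flag) tuple.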
+ bli_cntx_set_l3_sup_kers + ( + 28, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 0964ce463e..33ca2809dd 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,117 +43,108 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Set default kernel blocksizes and functions. bli_cntx_init_zen2_ref( cntx ); - // ------------------------------------------------------------------------- - - // Update the context with optimized native gemm micro-kernels and - // their storage preferences. 
- bli_cntx_set_l3_nat_ukrs - ( - 8, - - // gemm - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, - BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, - BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, - - // gemmtrsm_l - BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, - - // gemmtrsm_u - BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, - cntx - ); - + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 8, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + cntx + ); + + // // packm kernels + // bli_cntx_set_packm_kers + // ( + // 2, + // BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_8xk_gen_zen, + // BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_6xk_gen_zen, + // cntx + // ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 4, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 20, #if 1 - // Update the context with optimized packm kernels. - bli_cntx_set_packm_kers - ( - 8, - BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, - BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, - BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, - BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, - BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, - BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, - cntx - ); + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 4, - - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - cntx - ); - - // Update the context with optimized level-1v kernels. 
- bli_cntx_set_l1v_kers - ( - 16, - - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - - // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, - - // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, - - // dotxv - BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, - - // scalv - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - - //swap - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, - - //copy - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, - - //set - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - cntx - ); - - // Initialize level-3 blocksize objects with architecture-specific values. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + + //swap + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + //copy + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); #if AOCL_BLIS_MULTIINSTANCE - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); #else bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); #endif - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. 
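// (Each blocksize is registered via bli_cntx_set_blkszs together with the register blocksize it is kept a multiple of, e.g. MC with MR, NC with NR and KC with KR.)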
@@ -186,15 +177,15 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_blksz_init_easy( &thresh[ BLIS_KT ], 100000, 100000, -1, -1 ); #endif - // Initialize the context with the sup thresholds. - bli_cntx_set_l3_sup_thresh - ( - 3, - BLIS_MT, &thresh[ BLIS_MT ], - BLIS_NT, &thresh[ BLIS_NT ], - BLIS_KT, &thresh[ BLIS_KT ], - cntx - ); + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); #if 0 // Initialize the context with the sup handlers. @@ -268,17 +259,17 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, -1, -1 ); - // Update the context with the current architecture's register and cache - // blocksizes for small/unpacked level-3 problems. - bli_cntx_set_l3_sup_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index b5bbb05ed2..a6a6d9852a 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -4,7 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -108,44 +109,48 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. 
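// (The entry count below grows from 16 to 20: vectorized caxpyv/zaxpyv and cdotv/zdotv kernels are added for the complex datatypes.)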
bli_cntx_set_l1v_kers ( - 16, - - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + 20, +#if 1 + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, +#endif + // axpyv - // axpyv + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, - // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, - // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, - // dotxv - BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - // scalv - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + //swap + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, - //swap - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + //copy + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, - //copy - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - //set - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - - cntx - ); + cntx + ); // Initialize level-3 blocksize objects with architecture-specific values. // @@ -294,5 +299,4 @@ void bli_cntx_init_zen3( cntx_t* cntx ) BLIS_MR, &blkszs[ BLIS_MR ], cntx ); -} - +} \ No newline at end of file diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index e3c67fd55b..a3635826a9 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -49,40 +49,316 @@ void PASTEF77(ch,blasname) \ ftype* y, const f77_int* incy \ ) \ { \ - dim_t n0; \ - ftype* x0; \ - ftype* y0; \ - inc_t incx0; \ - inc_t incy0; \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - /* Convert/typecast negative values of n to zero. */ \ - bli_convert_blas_dim1( *n, n0 ); \ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ \ - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ \ - bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ - bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ \ - /* Call BLIS interface. 
*/ \ - PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - n0, \ - (ftype*)alpha, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } -#ifdef BLIS_ENABLE_BLAS +#ifdef BLIS_CONFIG_EPYC +void saxpy_ + ( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + bli_saxpyv_zen_int10( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void daxpy_ + ( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
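As an illustration (hypothetical values): with n = 4 and incx = -1, the caller passes &x[0]; the code below then sets x0 = x + 3 and incx0 = -1, so the elements are read as x[3], x[2], x[1], x[0].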
*/ + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + bli_daxpyv_zen_int10( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void caxpy_ + ( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + bli_caxpyv_zen_int5( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void zaxpy_ + ( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
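The same adjustment is applied to y below whenever incy is negative.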
*/ + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + bli_zaxpyv_zen_int5( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +#else INSERT_GENTFUNC_BLAS( axpy, axpyv ) #endif diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 0699cb22fd..c5dc7b2135 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -1,11 +1,9 @@ /* - BLIS An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -17,7 +15,6 @@ - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -29,12 +26,10 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ #include "blis.h" -#ifdef BLIS_ENABLE_BLAS // // Define BLAS-to-BLIS interfaces. @@ -49,51 +44,526 @@ ftype PASTEF772(ch,blasname,chc) \ const ftype* y, const f77_int* incy \ ) \ { \ - dim_t n0; \ - ftype* x0; \ - ftype* y0; \ - inc_t incx0; \ - inc_t incy0; \ - ftype rho; \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - /* Convert/typecast negative values of n to zero. */ \ - bli_convert_blas_dim1( *n, n0 ); \ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ \ - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ \ - bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ - bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ \ - /* Call BLIS interface. */ \ - PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ - ( \ - blis_conjx, \ - BLIS_NO_CONJUGATE, \ - n0, \ - x0, incx0, \ - y0, incy0, \ - &rho, \ - NULL, \ - NULL \ - ); \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + /* Finalize BLIS. 
*/ \ + bli_finalize_auto(); \ \ - return rho; \ + return rho; \ +} + +#ifdef BLIS_ENABLE_BLAS +#ifdef BLIS_CONFIG_EPYC +float sdot_ + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + float rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + /* Call BLIS kernel. */ + bli_sdotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + + /* Finalize BLIS. */ +// bli_finalize_auto(); + return rho; } +double ddot_ + ( + const f77_int* n, + const double* x, const f77_int* incx, + const double* y, const f77_int* incy + ) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + double rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + /* Call BLIS kernel. 
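The architecture-specific dotv kernel is invoked directly here rather than through the typed expert API used by the macro definition above, so no runtime dispatch takes place on this path.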
*/ + bli_ddotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + + /* Finalize BLIS. */ +// bli_finalize_auto(); + return rho; +} +#else INSERT_GENTFUNCDOTR_BLAS( dot, dotv ) +#endif +#ifdef BLIS_ENABLE_BLAS #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL +#ifdef BLIS_CONFIG_EPYC +scomplex cdotu_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; -INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) + /* Initialize BLIS. */ +// bli_init_auto(); -#else + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + + /* Finalize BLIS. */ +// bli_finalize_auto(); + return rho; +} + +dcomplex zdotu_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
*/ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + return rho; +} + + +scomplex cdotc_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + return rho; +} +dcomplex zdotc_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. 
Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + + return rho; +} +#else +INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) +#endif +#else // For the "intel" complex return type, use a hidden parameter to return the result #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ @@ -106,45 +576,45 @@ void PASTEF772(ch,blasname,chc) \ const ftype* y, const f77_int* incy \ ) \ { \ - dim_t n0; \ - ftype* x0; \ - ftype* y0; \ - inc_t incx0; \ - inc_t incy0; \ - ftype rho; \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - /* Convert/typecast negative values of n to zero. */ \ - bli_convert_blas_dim1( *n, n0 ); \ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ \ - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ \ - bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ - bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ \ - /* Call BLIS interface. */ \ - PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ - ( \ - blis_conjx, \ - BLIS_NO_CONJUGATE, \ - n0, \ - x0, incx0, \ - y0, incy0, \ - &rho, \ - NULL, \ - NULL \ - ); \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ \ - *rhop = rho; \ + *rhop = rho; \ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) - +#endif #endif @@ -160,16 +630,16 @@ float PASTEF77(sd,sdot) const float* y, const f77_int* incy ) { - return ( float ) - ( - ( double )(*sb) + - PASTEF77(d,sdot) - ( - n, - x, incx, - y, incy - ) - ); + return ( float ) + ( + ( double )(*sb) + + PASTEF77(d,sdot) + ( + n, + x, incx, + y, incy + ) + ); } // Input vectors stored in single precision, computed in double precision, @@ -181,40 +651,39 @@ double PASTEF77(d,sdot) const float* y, const f77_int* incy ) { - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - double rho; - dim_t i; + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + double rho; + dim_t i; - /* Initialization of BLIS is not required. */ + /* Initialization of BLIS is not required. */ - /* Convert/typecast negative values of n to zero. */ - bli_convert_blas_dim1( *n, n0 ); + /* Convert/typecast negative values of n to zero. 
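(Per BLAS convention, a negative n is treated as an empty operation.)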
*/ + bli_convert_blas_dim1( *n, n0 ); - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - bli_convert_blas_incv( n0, (float*)x, *incx, x0, incx0 ); - bli_convert_blas_incv( n0, (float*)y, *incy, y0, incy0 ); + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + bli_convert_blas_incv( n0, (float*)x, *incx, x0, incx0 ); + bli_convert_blas_incv( n0, (float*)y, *incy, y0, incy0 ); - rho = 0.0; + rho = 0.0; - for ( i = 0; i < n0; i++ ) - { - float* chi1 = x0 + (i )*incx0; - float* psi1 = y0 + (i )*incy0; + for ( i = 0; i < n0; i++ ) + { + float* chi1 = x0 + (i )*incx0; + float* psi1 = y0 + (i )*incy0; - bli_ddots( (( double )(*chi1)), - (( double )(*psi1)), rho ); - } + bli_ddots( (( double )(*chi1)), + (( double )(*psi1)), rho ); + } - /* Finalization of BLIS is not required, because initialization was - not required. */ + /* Finalization of BLIS is not required, because initialization was + not required. */ - return rho; + return rho; } -#endif - +#endif \ No newline at end of file diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index 873b7da536..936921c6bf 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. - Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2018 - 2020, The University of Texas at Austin. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,16 +41,16 @@ One 256-bit AVX register holds 8 SP elements. */ typedef union { - __m256 v; - float f[8] __attribute__((aligned(64))); + __m256 v; + float f[8] __attribute__((aligned(64))); } v8sf_t; /* Union data structure to access AVX registers * One 256-bit AVX register holds 4 DP elements. */ typedef union { - __m256d v; - double d[4] __attribute__((aligned(64))); + __m256d v; + double d[4] __attribute__((aligned(64))); } v4df_t; // ----------------------------------------------------------------------------- @@ -65,198 +65,198 @@ void bli_saxpyv_zen_int10 cntx_t* restrict cntx ) { - const dim_t n_elem_per_reg = 8; - - dim_t i; - - float* restrict x0; - float* restrict y0; - - __m256 alphav; - __m256 xv[10]; - __m256 yv[10]; - __m256 zv[10]; - - // If the vector dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim1( n ) || PASTEMAC(s,eq0)( *alpha ) ) return; - - // Initialize local pointers. - x0 = x; - y0 = y; - - if ( incx == 1 && incy == 1 ) - { - // Broadcast the alpha scalar to all elements of a vector register. - alphav = _mm256_broadcast_ss( alpha ); - - for ( i = 0; (i + 79) < n; i += 80 ) - { - // 80 elements will be processed per loop; 10 FMAs will run per loop. 
- xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); - yv[5] = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); - yv[6] = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); - yv[7] = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); - yv[8] = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); - yv[9] = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_ps( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_ps( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_ps( xv[4], alphav, yv[4] ); - zv[5] = _mm256_fmadd_ps( xv[5], alphav, yv[5] ); - zv[6] = _mm256_fmadd_ps( xv[6], alphav, yv[6] ); - zv[7] = _mm256_fmadd_ps( xv[7], alphav, yv[7] ); - zv[8] = _mm256_fmadd_ps( xv[8], alphav, yv[8] ); - zv[9] = _mm256_fmadd_ps( xv[9], alphav, yv[9] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (y0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_ps( (y0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_ps( (y0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_ps( (y0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_ps( (y0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_ps( (y0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; - y0 += 10*n_elem_per_reg; - } - - for ( ; (i + 39) < n; i += 40 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_ps( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_ps( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_ps( xv[4], alphav, yv[4] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (y0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; - y0 += 5*n_elem_per_reg; - } - - for ( ; (i + 31) < n; i += 32 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - 
- yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_ps( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_ps( xv[3], alphav, yv[3] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; - y0 += 4*n_elem_per_reg; - } - - for ( ; (i + 15) < n; i += 16 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; - y0 += 2*n_elem_per_reg; - } - - for ( ; (i + 7) < n; i += 8 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); - - _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - y0 += 1*n_elem_per_reg; - } - - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from from AVX to SSE instructions (which may occur - // as soon as the n_left cleanup loop below if BLIS is compiled with - // -mfpmath=sse). - _mm256_zeroupper(); - - for ( ; (i + 0) < n; i += 1 ) - { - *y0 += (*alpha) * (*x0); - - x0 += 1; - y0 += 1; - } - } - else - { - const float alphac = *alpha; - - for ( i = 0; i < n; ++i ) - { - const float x0c = *x0; - - *y0 += alphac * x0c; - - x0 += incx; - y0 += incy; - } - } + const dim_t n_elem_per_reg = 8; + + dim_t i; + + float* restrict x0; + float* restrict y0; + + __m256 alphav; + __m256 xv[10]; + __m256 yv[10]; + __m256 zv[10]; + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(s,eq0)( *alpha ) ) return; + + // Initialize local pointers. + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. + alphav = _mm256_broadcast_ss( alpha ); + + for ( i = 0; (i + 79) < n; i += 80 ) + { + // 80 elements will be processed per loop; 10 FMAs will run per loop. 
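+			// (Ten ymm registers x eight single-precision elements = 80 elements
+			// per iteration; the deep unroll amortizes loop overhead and exposes
+			// many independent loads, FMAs and stores to the out-of-order core.)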
+ xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + yv[8] = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); + yv[9] = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); + zv[2] = _mm256_fmadd_ps( xv[2], alphav, yv[2] ); + zv[3] = _mm256_fmadd_ps( xv[3], alphav, yv[3] ); + zv[4] = _mm256_fmadd_ps( xv[4], alphav, yv[4] ); + zv[5] = _mm256_fmadd_ps( xv[5], alphav, yv[5] ); + zv[6] = _mm256_fmadd_ps( xv[6], alphav, yv[6] ); + zv[7] = _mm256_fmadd_ps( xv[7], alphav, yv[7] ); + zv[8] = _mm256_fmadd_ps( xv[8], alphav, yv[8] ); + zv[9] = _mm256_fmadd_ps( xv[9], alphav, yv[9] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_ps( (y0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (y0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (y0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (y0 + 7*n_elem_per_reg), zv[7] ); + _mm256_storeu_ps( (y0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (y0 + 9*n_elem_per_reg), zv[9] ); + + x0 += 10*n_elem_per_reg; + y0 += 10*n_elem_per_reg; + } + + for ( ; (i + 39) < n; i += 40 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); + zv[2] = _mm256_fmadd_ps( xv[2], alphav, yv[2] ); + zv[3] = _mm256_fmadd_ps( xv[3], alphav, yv[3] ); + zv[4] = _mm256_fmadd_ps( xv[4], alphav, yv[4] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_ps( (y0 + 4*n_elem_per_reg), zv[4] ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + for ( ; (i + 31) < n; i += 32 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + 
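+			// Load the corresponding four registers of y.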
+ yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); + zv[2] = _mm256_fmadd_ps( xv[2], alphav, yv[2] ); + zv[3] = _mm256_fmadd_ps( xv[3], alphav, yv[3] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), zv[3] ); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_ps( xv[1], alphav, yv[1] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), zv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_ps( xv[0], alphav, yv[0] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + for ( ; (i + 0) < n; i += 1 ) + { + *y0 += (*alpha) * (*x0); + + x0 += 1; + y0 += 1; + } + } + else + { + const float alphac = *alpha; + + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + + *y0 += alphac * x0c; + + x0 += incx; + y0 += incy; + } + } } // ----------------------------------------------------------------------------- @@ -271,197 +271,790 @@ void bli_daxpyv_zen_int10 cntx_t* restrict cntx ) { - const dim_t n_elem_per_reg = 4; - - dim_t i; - - double* restrict x0 = x; - double* restrict y0 = y; - - __m256d alphav; - __m256d xv[10]; - __m256d yv[10]; - __m256d zv[10]; - - // If the vector dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) return; - - // Initialize local pointers. - x0 = x; - y0 = y; - - if ( incx == 1 && incy == 1 ) - { - // Broadcast the alpha scalar to all elements of a vector register. - alphav = _mm256_broadcast_sd( alpha ); - - for ( i = 0; (i + 39) < n; i += 40 ) - { - // 40 elements will be processed per loop; 10 FMAs will run per loop. 
- xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); - yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); - yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); - yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); - yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); - zv[5] = _mm256_fmadd_pd( xv[5], alphav, yv[5] ); - zv[6] = _mm256_fmadd_pd( xv[6], alphav, yv[6] ); - zv[7] = _mm256_fmadd_pd( xv[7], alphav, yv[7] ); - zv[8] = _mm256_fmadd_pd( xv[8], alphav, yv[8] ); - zv[9] = _mm256_fmadd_pd( xv[9], alphav, yv[9] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_pd( (y0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_pd( (y0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_pd( (y0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_pd( (y0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_pd( (y0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; - y0 += 10*n_elem_per_reg; - } - - for ( ; (i + 19) < n; i += 20 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; - y0 += 5*n_elem_per_reg; - } - - for ( ; (i + 15) < n; i += 16 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - 
- yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; - y0 += 4*n_elem_per_reg; - } - - for ( ; i + 7 < n; i += 8 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; - y0 += 2*n_elem_per_reg; - } - - for ( ; i + 3 < n; i += 4 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - y0 += 1*n_elem_per_reg; - } - - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from from AVX to SSE instructions (which may occur - // as soon as the n_left cleanup loop below if BLIS is compiled with - // -mfpmath=sse). - _mm256_zeroupper(); - - for ( ; i < n; i += 1 ) - { - *y0 += (*alpha) * (*x0); - - y0 += 1; - x0 += 1; - } - } - else - { - const double alphac = *alpha; - - for ( i = 0; i < n; ++i ) - { - const double x0c = *x0; - - *y0 += alphac * x0c; - - x0 += incx; - y0 += incy; - } - } + const dim_t n_elem_per_reg = 4; + + dim_t i; + + double* restrict x0 = x; + double* restrict y0 = y; + + __m256d alphav; + __m256d xv[10]; + __m256d yv[10]; + __m256d zv[10]; + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) return; + + // Initialize local pointers. + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. + alphav = _mm256_broadcast_sd( alpha ); + + for ( i = 0; (i + 39) < n; i += 40 ) + { + // 40 elements will be processed per loop; 10 FMAs will run per loop. 
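+			// (Ten ymm registers x four double-precision elements = 40 elements
+			// per iteration.)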
+ xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); + yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); + zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); + zv[5] = _mm256_fmadd_pd( xv[5], alphav, yv[5] ); + zv[6] = _mm256_fmadd_pd( xv[6], alphav, yv[6] ); + zv[7] = _mm256_fmadd_pd( xv[7], alphav, yv[7] ); + zv[8] = _mm256_fmadd_pd( xv[8], alphav, yv[8] ); + zv[9] = _mm256_fmadd_pd( xv[9], alphav, yv[9] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (y0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (y0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (y0 + 7*n_elem_per_reg), zv[7] ); + _mm256_storeu_pd( (y0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (y0 + 9*n_elem_per_reg), zv[9] ); + + x0 += 10*n_elem_per_reg; + y0 += 10*n_elem_per_reg; + } + + for ( ; (i + 19) < n; i += 20 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); + zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + 
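+			// Load the corresponding four registers of y.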
+ yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + } + + for ( ; i + 7 < n; i += 8 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + for ( ; i + 3 < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + for ( ; i < n; i += 1 ) + { + *y0 += (*alpha) * (*x0); + + y0 += 1; + x0 += 1; + } + } + else + { + const double alphac = *alpha; + + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + + *y0 += alphac * x0c; + + x0 += incx; + y0 += incy; + } + } } +// ----------------------------------------------------------------------------- + +void bli_caxpyv_zen_int5 + ( + conj_t conjx, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 8; + + dim_t i; + + float* restrict x0; + float* restrict y0; + float* restrict alpha0; + + float alphaR, alphaI; + + //scomplex alpha => aR + aI i + __m256 alphaRv; // for braodcast vector aR (real part of alpha) + __m256 alphaIv; // for braodcast vector aI (imaginary part of alpha) + __m256 xv[10]; + __m256 xShufv[10]; + __m256 yv[10]; + + conj_t conjx_use = conjx; + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Initialize local pointers. + x0 = (float*)x; + y0 = (float*)y; + alpha0 = (float*)alpha; + + alphaR = alpha->real; + alphaI = alpha->imag; + + if ( incx == 1 && incy == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. 
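+		// alpha is expanded into two vectors: alphaRv holds copies of the real
+		// part and alphaIv the imaginary part with alternating signs (the sign
+		// pattern moves to alphaRv when conjugation is requested), so that two
+		// FMAs per register perform the complex multiply-accumulate.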
+ if ( !bli_is_conj (conjx) ) // If BLIS_NO_CONJUGATE + { + alphaRv = _mm256_broadcast_ss( &alphaR ); + + alphaIv = _mm256_set_ps(alphaI, -alphaI, alphaI, -alphaI, alphaI, -alphaI, alphaI, -alphaI); + + } + else + { + alphaIv = _mm256_broadcast_ss( &alphaI ); + + alphaRv = _mm256_set_ps(-alphaR, alphaR, -alphaR, alphaR, -alphaR, alphaR, -alphaR, alphaR); + } + + //----------Scalar algorithm BLIS_NO_CONJUGATE arg------------- + // y = alpha*x + y + // y = (aR + aIi) * (xR + xIi) + (yR + yIi) + // y = aR.xR + aR.xIi + aIi.xR - aIxI + (yR + yIi) + // y = aR.xR - aIxI + yR + aR.xIi + xR.aIi + yIi + // y = ( aR.xR - aIxI + yR ) + ( aR.xI + aI.xR + yI )i + + // SIMD algorithm + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 (shuffle xv) + // arv = aR aR aR aR aR aR aR aR + // aiv = aI -aI aI -aI aI -aI aI -aI + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + + + //----------Scalar algorithm for BLIS_CONJUGATE arg------------- + // y = alpha*conj(x) + y + // y = (aR + aIi) * (xR - xIi) + (yR + yIi) + // y = aR.xR - aR.xIi + aIi.xR + aIxI + (yR + yIi) + // y = aR.xR + aIxI + yR - aR.xIi + xR.aIi + yIi + // y = ( aR.xR + aIxI + yR ) + ( -aR.xI + aI.xR + yI )i + // y = ( aR.xR + aIxI + yR ) + (aI.xR - aR.xI + yI)i + + // SIMD algorithm + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR -aR aR -aR aR -aR aR -aR + // aiv = aI aI aI aI aI aI aI aI + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + + // step 1 : Shuffle xv vector -> xv' + // step 2 : fma yv = arv*xv + yv + // step 3 : fma yv = aiv*xv' + yv (old) + // yv = aiv*xv' + arv*xv + yv + + for ( i= 0 ; (i + 19) < n; i += 20 ) + { + // 20 elements will be processed per loop; 10 FMAs will run per loop. + + // alphaRv = aR aR aR aR aR aR aR aR + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xShufv[0] = _mm256_permute_ps( xv[0], 0xB1); + xShufv[1] = _mm256_permute_ps( xv[1], 0xB1); + xShufv[2] = _mm256_permute_ps( xv[2], 0xB1); + xShufv[3] = _mm256_permute_ps( xv[3], 0xB1); + xShufv[4] = _mm256_permute_ps( xv[4], 0xB1); + + // alphaIv = -aI aI -aI aI -aI aI -aI aI + + // yv = ar*xv + yv + // = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ... + yv[0] = _mm256_fmadd_ps( xv[0], alphaRv ,yv[0]); + yv[1] = _mm256_fmadd_ps( xv[1], alphaRv ,yv[1]); + yv[2] = _mm256_fmadd_ps( xv[2], alphaRv ,yv[2]); + yv[3] = _mm256_fmadd_ps( xv[3], alphaRv ,yv[3]); + yv[4] = _mm256_fmadd_ps( xv[4], alphaRv ,yv[4]); + + // yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + // = -aI*xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ......... 
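+			// i.e. each even lane accumulates a real part aR.xR - aI.xI + yR and
+			// each odd lane an imaginary part aR.xI + aI.xR + yI (for the
+			// conjugated case the alternating sign sits in alphaRv instead).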
+ yv[0] = _mm256_fmadd_ps( xShufv[0], alphaIv, yv[0]); + yv[1] = _mm256_fmadd_ps( xShufv[1], alphaIv, yv[1]); + yv[2] = _mm256_fmadd_ps( xShufv[2], alphaIv, yv[2]); + yv[3] = _mm256_fmadd_ps( xShufv[3], alphaIv, yv[3]); + yv[4] = _mm256_fmadd_ps( xShufv[4], alphaIv, yv[4]); + + // Store back the results + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), yv[3] ); + _mm256_storeu_ps( (y0 + 4*n_elem_per_reg), yv[4] ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // alphaRv = aR aR aR aR aR aR aR aR + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xShufv[0] = _mm256_permute_ps( xv[0], 0xB1); + xShufv[1] = _mm256_permute_ps( xv[1], 0xB1); + + // alphaIv = -aI aI -aI aI -aI aI -aI aI + + // yv = ar*xv + yv + // = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ... + yv[0] = _mm256_fmadd_ps( xv[0], alphaRv ,yv[0]); + yv[1] = _mm256_fmadd_ps( xv[1], alphaRv ,yv[1]); + + // yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + // = -aI*xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ......... + yv[0] = _mm256_fmadd_ps( xShufv[0], alphaIv, yv[0]); + yv[1] = _mm256_fmadd_ps( xShufv[1], alphaIv, yv[1]); + + // Store back the result + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // alphaRv = aR aR aR aR aR aR aR aR + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xShufv[0] = _mm256_permute_ps( xv[0], 0xB1); + + // alphaIv = -aI aI -aI aI -aI aI -aI aI + + // yv = ar*xv + yv + // = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ... + yv[0] = _mm256_fmadd_ps( xv[0], alphaRv ,yv[0]); + + // yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + // = aR.xR1 - aI*xI1 + yR1, aR.xI1 + aI.xR1 + yI1 + yv[0] = _mm256_fmadd_ps( xShufv[0], alphaIv, yv[0]); + + // Store back the result + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + /* Residual values are calculated here + y0 += (alpha) * (x0); --> BLIS_NO_CONJUGATE + y0 += ( aR.xR - aIxI + yR ) + ( aR.xI + aI.xR + yI )i + + y0 += (alpha) * conjx(x0); --> BLIS_CONJUGATE + y0 = ( aR.xR + aIxI + yR ) + (aI.xR - aR.xI + yI)i */ + + if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE + { + for ( ; (i + 0) < n; i += 1 ) + { + // real part: ( aR.xR - aIxI + yR ) + *y0 += *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1)); + // img part: ( aR.xI + aI.xR + yI ) + *(y0 + 1) += *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0); + x0 += 2; + y0 += 2; + } + } + else // BLIS_CONJUGATE + { + for ( ; (i + 0) < n; i += 1 ) + { + // real part: ( aR.xR + aIxI + yR ) + *y0 += *alpha0 * (*x0) + (*(alpha0 + 1)) * (*(x0+1)); + // img part: ( aI.xR - aR.xI + yI ) + *(y0 + 1) += (*(alpha0 + 1)) * (*x0) - (*alpha0) * (*(x0+1)); + x0 += 2; + y0 += 2; + } + } + + } + else + { + const float alphar = *alpha0; + const float alphai = *(alpha0 + 1); + + if ( !bli_is_conj(conjx_use) ) + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float x1c = *( x0+1 ); + + *y0 += alphar * x0c - alphai * x1c; + *(y0 + 1) += alphar * x1c + alphai * x0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float x1c = *( x0+1 ); + + *y0 += alphar * x0c + alphai * x1c; + *(y0 + 1) += alphai * x0c - alphar * x1c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + + } +} + +// ----------------------------------------------------------------------------- + +void bli_zaxpyv_zen_int5 + ( + conj_t conjx, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 4; + + dim_t i; + + double* restrict x0; + double* restrict y0; + double* restrict alpha0; + + double alphaR, alphaI; + + __m256d alphaRv; // for braodcast vector aR (real part of alpha) + __m256d alphaIv; // for braodcast vector aI (imaginary part of alpha) + __m256d xv[5]; + __m256d xShufv[5]; + __m256d yv[5]; + + conj_t conjx_use = conjx; + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Initialize local pointers. + x0 = (double*)x; + y0 = (double*)y; + alpha0 = (double*)alpha; + + alphaR = alpha->real; + alphaI = alpha->imag; + + if ( incx == 1 && incy == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. 
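+		// As in the single-precision kernel, alphaRv/alphaIv are built so that
+		// two FMAs per register perform the complex multiply-accumulate; here
+		// the signed vector is filled lane by lane, which relies on the
+		// GCC/Clang vector subscript extension. An equivalent (illustrative)
+		// portable form would be, e.g.:
+		//   alphaIv = _mm256_set_pd( alphaI, -alphaI, alphaI, -alphaI );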
+ if ( !bli_is_conj (conjx) ) // If BLIS_NO_CONJUGATE + { + alphaRv = _mm256_broadcast_sd( &alphaR ); + + alphaIv[0] = -alphaI; + alphaIv[1] = alphaI; + alphaIv[2] = -alphaI; + alphaIv[3] = alphaI; + } + else + { + alphaIv = _mm256_broadcast_sd( &alphaI ); + + alphaRv[0] = alphaR; + alphaRv[1] = -alphaR; + alphaRv[2] = alphaR; + alphaRv[3] = -alphaR; + } + + // --------Scalar algorithm BLIS_NO_CONJUGATE arg------------- + // y = alpha*x + y + // y = (aR + aIi) * (xR + xIi) + (yR + yIi) + // y = aR.xR + aR.xIi + aIi.xR - aIxI + (yR + yIi) + // y = aR.xR - aIxI + yR + aR.xIi + xR.aIi + yIi + // y = ( aR.xR - aIxI + yR ) + ( aR.xI + aI.xR + yI )i + + // SIMD algorithm + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR aR aR aR + // aiv = -aI aI -aI aI + // yv = yR1 yI1 yR2 yI2 + + // S1 : xv' = xI1 xR1 xI2 xR2 (Shuffle) + // S2 : reg0 = (aR.xR1 aR.xI1 aR.xR2 aR.xI2) + yv + // S3 : reg1 = (-aI.xI1 aI.xR1 -aI.xI2 aI.xR2) + reg0 + //---------------------------------------------------------------- + // Ans : aR.xR1 -aI.xI1 + yR1, aR.xI1 + aI.xR1 + yI1, aR.xR2 -aI.xI2 + yR2, aR.xI2 + aI.xR2 + yI2 + + //----------Scalar algorithm for BLIS_CONJUGATE arg------------- + // y = alpha*conj(x) + y + // y = (aR + aIi) * (xR - xIi) + (yR + yIi) + // y = aR.xR - aR.xIi + aIi.xR + aIxI + (yR + yIi) + // y = aR.xR + aIxI + yR - aR.xIi + xR.aIi + yIi + // y = ( aR.xR + aIxI + yR ) + ( -aR.xI + aI.xR + yI )i + // y = ( aR.xR + aIxI + yR ) + (aI.xR - aR.xI + yI)i + + // SIMD algorithm + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR -aR aR -aR + // aiv = aI aI aI aI + // yv = yR1 yI1 yR2 yI2 + + // step 1 : Shuffle xv vector + // reg xv : xv' = xI1 xR1 xI2 xR2 + // step 2 : fma :yv = ar*xv + yv = ar*xv + yv + // step 3 : fma :yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + + for ( i = 0; (i + 9) < n; i += 10 ) + { + // 10 elements will be processed per loop; 10 FMAs will run per loop. + + // alphaRv = aR aR aR aR + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + // xv' = xI1 xRI xI2 xR2 + xShufv[0] = _mm256_permute_pd( xv[0], 5); + xShufv[1] = _mm256_permute_pd( xv[1], 5); + xShufv[2] = _mm256_permute_pd( xv[2], 5); + xShufv[3] = _mm256_permute_pd( xv[3], 5); + xShufv[4] = _mm256_permute_pd( xv[4], 5); + + // alphaIv = -aI aI -aI aI + + // yv = ar*xv + yv + // = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ... + yv[0] = _mm256_fmadd_pd( xv[0], alphaRv ,yv[0]); + yv[1] = _mm256_fmadd_pd( xv[1], alphaRv ,yv[1]); + yv[2] = _mm256_fmadd_pd( xv[2], alphaRv ,yv[2]); + yv[3] = _mm256_fmadd_pd( xv[3], alphaRv ,yv[3]); + yv[4] = _mm256_fmadd_pd( xv[4], alphaRv ,yv[4]); + + // yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + // = -aI*xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ......... 
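+			// i.e. lanes 0 and 2 accumulate the real parts aR.xR - aI.xI + yR,
+			// lanes 1 and 3 the imaginary parts aR.xI + aI.xR + yI (signs move
+			// to alphaRv for the conjugated case).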
+ yv[0] = _mm256_fmadd_pd( xShufv[0], alphaIv, yv[0]); + yv[1] = _mm256_fmadd_pd( xShufv[1], alphaIv, yv[1]); + yv[2] = _mm256_fmadd_pd( xShufv[2], alphaIv, yv[2]); + yv[3] = _mm256_fmadd_pd( xShufv[3], alphaIv, yv[3]); + yv[4] = _mm256_fmadd_pd( xShufv[4], alphaIv, yv[4]); + + // Store back the result + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3] ); + _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), yv[4] ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // alphaRv = aR aR aR aR + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // xv' = xI1 xRI xI2 xR2 + xShufv[0] = _mm256_permute_pd( xv[0], 5); + xShufv[1] = _mm256_permute_pd( xv[1], 5); + + // alphaIv = -aI aI -aI aI + + // yv = ar*xv + yv + // = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ... + yv[0] = _mm256_fmadd_pd( xv[0], alphaRv ,yv[0]); + yv[1] = _mm256_fmadd_pd( xv[1], alphaRv ,yv[1]); + + // yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + // = -aI*xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ......... + yv[0] = _mm256_fmadd_pd( xShufv[0], alphaIv, yv[0]); + yv[1] = _mm256_fmadd_pd( xShufv[1], alphaIv, yv[1]); + + // Store back the result + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 2 ) + { + // alphaRv = aR aR aR aR + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // xv' = xI1 xRI xI2 xR2 + xShufv[0] = _mm256_permute_pd( xv[0], 5); + + // alphaIv = -aI aI -aI aI + + // yv = ar*xv + yv + // = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ... + yv[0] = _mm256_fmadd_pd( xv[0], alphaRv ,yv[0]); + + // yv = ai*xv' + yv (old) + // yv = ai*xv' + ar*xv + yv + // = -aI*xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ......... + yv[0] = _mm256_fmadd_pd( xShufv[0], alphaIv, yv[0]); + + // Store back the result + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + /* Residual values are calculated here + y0 += (alpha) * (x0); --> BLIS_NO_CONJUGATE + y0 += ( aR.xR - aIxI + yR ) + ( aR.xI + aI.xR + yI )i + + y0 += (alpha) * conjx(x0); --> BLIS_CONJUGATE + y0 = ( aR.xR + aIxI + yR ) + (aI.xR - aR.xI + yI)i */ + + if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE + { + for ( ; (i + 0) < n; i += 1 ) + { + // real part: ( aR.xR - aIxI + yR ) + *y0 += *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1)); + // img part: ( aR.xI + aI.xR + yI ) + *(y0 + 1) += *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0); + x0 += 2; + y0 += 2; + } + } + else // BLIS_CONJUGATE + { + for ( ; (i + 0) < n; i += 1 ) + { + // real part: ( aR.xR + aIxI + yR ) + *y0 += *alpha0 * (*x0) + (*(alpha0 + 1)) * (*(x0+1)); + // img part: ( aI.xR - aR.xI + yI ) + *(y0 + 1) += (*(alpha0 + 1)) * (*x0) - (*alpha0) * (*(x0+1)); + x0 += 2; + y0 += 2; + } + } + } + else + { + const double alphar = *alpha0; + const double alphai = *(alpha0 + 1); + + if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double x1c = *( x0+1 ); + + *y0 += alphar * x0c - alphai * x1c; + *(y0 + 1) += alphar * x1c + alphai * x0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + else // BLIS_CONJUGATE + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double x1c = *( x0+1 ); + + *y0 += alphar * x0c + alphai * x1c; + *(y0 + 1) += alphai * x0c - alphar * x1c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + } +} \ No newline at end of file diff --git a/kernels/zen/1/bli_dotv_zen_int10.c b/kernels/zen/1/bli_dotv_zen_int10.c index 8c445849b0..a8704a4bde 100644 --- a/kernels/zen/1/bli_dotv_zen_int10.c +++ b/kernels/zen/1/bli_dotv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2020, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -41,16 +41,16 @@ One 256-bit AVX register holds 8 SP elements. */ typedef union { - __m256 v; - float f[8] __attribute__((aligned(64))); + __m256 v; + float f[8] __attribute__((aligned(64))); } v8sf_t; /* Union data structure to access AVX registers * One 256-bit AVX register holds 4 DP elements. */ typedef union { - __m256d v; - double d[4] __attribute__((aligned(64))); + __m256d v; + double d[4] __attribute__((aligned(64))); } v4df_t; // ----------------------------------------------------------------------------- @@ -66,182 +66,182 @@ void bli_sdotv_zen_int10 cntx_t* restrict cntx ) { - const dim_t n_elem_per_reg = 8; - - dim_t i; - - float* restrict x0; - float* restrict y0; - - float rho0 = 0.0; - - __m256 xv[10]; - __m256 yv[10]; - v8sf_t rhov[10]; - - // If the vector dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim1( n ) ) - { - PASTEMAC(s,set0s)( *rho ); - return; - } - - // Initialize local pointers. 
- x0 = x; - y0 = y; - - PASTEMAC(s,set0s)( rho0 ); - - if ( incx == 1 && incy == 1 ) - { - rhov[0].v = _mm256_setzero_ps(); - rhov[1].v = _mm256_setzero_ps(); - rhov[2].v = _mm256_setzero_ps(); - rhov[3].v = _mm256_setzero_ps(); - rhov[4].v = _mm256_setzero_ps(); - rhov[5].v = _mm256_setzero_ps(); - rhov[6].v = _mm256_setzero_ps(); - rhov[7].v = _mm256_setzero_ps(); - rhov[8].v = _mm256_setzero_ps(); - rhov[9].v = _mm256_setzero_ps(); - - for ( i = 0 ; (i + 79) < n; i += 80 ) - { - // 80 elements will be processed per loop; 10 FMAs will run per loop. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); - yv[5] = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); - yv[6] = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); - yv[7] = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); - yv[8] = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); - yv[9] = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); - rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); - rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); - rhov[5].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[5].v ); - rhov[6].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[6].v ); - rhov[7].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[7].v ); - rhov[8].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[8].v ); - rhov[9].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[9].v ); - - x0 += 10*n_elem_per_reg; - y0 += 10*n_elem_per_reg; - } - - rhov[0].v += rhov[5].v; - rhov[1].v += rhov[6].v; - rhov[2].v += rhov[7].v; - rhov[3].v += rhov[8].v; - rhov[4].v += rhov[9].v; - - for ( ; (i + 39) < n; i += 40 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); - rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); - rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); - - x0 += 5*n_elem_per_reg; - y0 += 5*n_elem_per_reg; - } - - rhov[0].v += rhov[2].v; - rhov[1].v += rhov[3].v; - rhov[0].v += rhov[4].v; - - for ( ; (i + 15) < n; i += 16 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 
0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - - x0 += 2*n_elem_per_reg; - y0 += 2*n_elem_per_reg; - } - - rhov[0].v += rhov[1].v; - - for ( ; (i + 7) < n; i += 8 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); - - x0 += 1*n_elem_per_reg; - y0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - rho0 += (*x0) * (*y0); - x0 += 1; - y0 += 1; - } - - rho0 += rhov[0].f[0] + rhov[0].f[1] + - rhov[0].f[2] + rhov[0].f[3] + - rhov[0].f[4] + rhov[0].f[5] + - rhov[0].f[6] + rhov[0].f[7]; - - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from from AVX to SSE instructions (which may occur - // later, especially if BLIS is compiled with -mfpmath=sse). - _mm256_zeroupper(); - } - else - { - for ( i = 0; i < n; ++i ) - { - const float x0c = *x0; - const float y0c = *y0; - - rho0 += x0c * y0c; - - x0 += incx; - y0 += incy; - } - } - - // Copy the final result into the output variable. - PASTEMAC(s,copys)( rho0, *rho ); + const dim_t n_elem_per_reg = 8; + + dim_t i; + + float* restrict x0; + float* restrict y0; + + float rho0 = 0.0; + + __m256 xv[10]; + __m256 yv[10]; + v8sf_t rhov[10]; + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) ) + { + PASTEMAC(s,set0s)( *rho ); + return; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + PASTEMAC(s,set0s)( rho0 ); + + if ( incx == 1 && incy == 1 ) + { + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + rhov[4].v = _mm256_setzero_ps(); + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + rhov[8].v = _mm256_setzero_ps(); + rhov[9].v = _mm256_setzero_ps(); + + for (i=0 ; (i + 79) < n; i += 80 ) + { + // 80 elements will be processed per loop; 10 FMAs will run per loop. 
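+			// (Ten separate accumulators rhov[0..9] are used so that each FMA
+			// depends only on its own register's previous value; the partial
+			// sums are folded together after the unrolled loops.)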
+ xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5] = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6] = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7] = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + yv[8] = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); + yv[9] = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[7].v ); + rhov[8].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[8].v ); + rhov[9].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[9].v ); + + x0 += 10*n_elem_per_reg; + y0 += 10*n_elem_per_reg; + } + + rhov[0].v += rhov[5].v; + rhov[1].v += rhov[6].v; + rhov[2].v += rhov[7].v; + rhov[3].v += rhov[8].v; + rhov[4].v += rhov[9].v; + + for ( ; (i + 39) < n; i += 40 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + rhov[0].v += rhov[4].v; + + for ( ; (i + 15) < n; i += 16 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + rhov[0].v += rhov[1].v; + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 
1 ) + { + rho0 += (*x0) * (*y0); + x0 += 1; + y0 += 1; + } + + rho0 += rhov[0].f[0] + rhov[0].f[1] + + rhov[0].f[2] + rhov[0].f[3] + + rhov[0].f[4] + rhov[0].f[5] + + rhov[0].f[6] + rhov[0].f[7]; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + } + else + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float y0c = *y0; + + rho0 += x0c * y0c; + + x0 += incx; + y0 += incy; + } + } + + // Copy the final result into the output variable. + PASTEMAC(s,copys)( rho0, *rho ); } // ----------------------------------------------------------------------------- @@ -257,202 +257,803 @@ void bli_ddotv_zen_int10 cntx_t* restrict cntx ) { - const dim_t n_elem_per_reg = 4; - - dim_t i; - - double* restrict x0; - double* restrict y0; - - double rho0 = 0.0; - - __m256d xv[10]; - __m256d yv[10]; - v4df_t rhov[10]; - - // If the vector dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim1( n ) ) - { - PASTEMAC(d,set0s)( *rho ); - return; - } - - // Initialize local pointers. - x0 = x; - y0 = y; - - PASTEMAC(d,set0s)( rho0 ); - - if ( incx == 1 && incy == 1 ) - { - rhov[0].v = _mm256_setzero_pd(); - rhov[1].v = _mm256_setzero_pd(); - rhov[2].v = _mm256_setzero_pd(); - rhov[3].v = _mm256_setzero_pd(); - rhov[4].v = _mm256_setzero_pd(); - rhov[5].v = _mm256_setzero_pd(); - rhov[6].v = _mm256_setzero_pd(); - rhov[7].v = _mm256_setzero_pd(); - rhov[8].v = _mm256_setzero_pd(); - rhov[9].v = _mm256_setzero_pd(); - - for ( i = 0; (i + 39) < n; i += 40 ) - { - // 80 elements will be processed per loop; 10 FMAs will run per loop. 
- xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); - yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); - yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); - yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); - yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); - rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); - rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); - rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v ); - rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v ); - rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v ); - rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v ); - rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v ); - - x0 += 10*n_elem_per_reg; - y0 += 10*n_elem_per_reg; - } - - rhov[0].v += rhov[5].v; - rhov[1].v += rhov[6].v; - rhov[2].v += rhov[7].v; - rhov[3].v += rhov[8].v; - rhov[4].v += rhov[9].v; - - for ( ; (i + 19) < n; i += 20 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); - rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); - rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); - - x0 += 5*n_elem_per_reg; - y0 += 5*n_elem_per_reg; - } - - rhov[0].v += rhov[4].v; - - for ( ; (i + 15) < n; i += 16 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); - rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); - - x0 += 4*n_elem_per_reg; - y0 += 4*n_elem_per_reg; - } - - rhov[0].v += rhov[2].v; - rhov[1].v += 
rhov[3].v; - - for ( ; (i + 7) < n; i += 8 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - - x0 += 2*n_elem_per_reg; - y0 += 2*n_elem_per_reg; - } - - rhov[0].v += rhov[1].v; - - for ( ; (i + 3) < n; i += 4 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); - - x0 += 1*n_elem_per_reg; - y0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - rho0 += (*x0) * (*y0); - - x0 += 1; - y0 += 1; - } - - // Manually add the results from above to finish the sum. - rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3]; - - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from from AVX to SSE instructions (which may occur - // later, especially if BLIS is compiled with -mfpmath=sse). - _mm256_zeroupper(); - } - else - { - for ( i = 0; i < n; ++i ) - { - const double x0c = *x0; - const double y0c = *y0; - - rho0 += x0c * y0c; - - x0 += incx; - y0 += incy; - } - } - - // Copy the final result into the output variable. - PASTEMAC(d,copys)( rho0, *rho ); + const dim_t n_elem_per_reg = 4; + + dim_t i; + + double* restrict x0; + double* restrict y0; + + double rho0 = 0.0; + + __m256d xv[10]; + __m256d yv[10]; + v4df_t rhov[10]; + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) ) + { + PASTEMAC(d,set0s)( *rho ); + return; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + PASTEMAC(d,set0s)( rho0 ); + + if ( incx == 1 && incy == 1 ) + { + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + rhov[8].v = _mm256_setzero_pd(); + rhov[9].v = _mm256_setzero_pd(); + + for ( i = 0; (i + 39) < n; i += 40 ) + { + // 80 elements will be processed per loop; 10 FMAs will run per loop. 
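+			// That is, each iteration consumes 40 doubles from x and 40 from y
+			// (10 ymm registers of 4 elements each per vector), issuing 10 FMAs
+			// into 10 independent accumulators so that FMA latency is hidden.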
+ xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); + yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v ); + rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v ); + rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v ); + + x0 += 10*n_elem_per_reg; + y0 += 10*n_elem_per_reg; + } + + rhov[0].v += rhov[5].v; + rhov[1].v += rhov[6].v; + rhov[2].v += rhov[7].v; + rhov[3].v += rhov[8].v; + rhov[4].v += rhov[9].v; + + for ( ; (i + 19) < n; i += 20 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + rhov[0].v += rhov[4].v; + + for ( ; (i + 15) < n; i += 16 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + } + + rhov[0].v += rhov[2].v; + rhov[1].v += 
rhov[3].v; + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + rhov[0].v += rhov[1].v; + + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + rho0 += (*x0) * (*y0); + + x0 += 1; + y0 += 1; + } + + // Manually add the results from above to finish the sum. + rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3]; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + } + else + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + rho0 += x0c * y0c; + + x0 += incx; + y0 += incy; + } + } + + // Copy the final result into the output variable. + PASTEMAC(d,copys)( rho0, *rho ); } +// ----------------------------------------------------------------------------- + + +void bli_cdotv_zen_int5 + ( + conj_t conjx, + conj_t conjy, + dim_t n, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + scomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 8; + + dim_t i; + + float* restrict x0; + float* restrict y0; + + scomplex rho0 ; + rho0.real = 0.0; + rho0.imag = 0.0; + + __m256 xv[5]; + __m256 yv[5]; + __m256 zv[5]; + v8sf_t rhov[10]; + + conj_t conjx_use = conjx; + /* If y must be conjugated, we do so indirectly by first toggling the + effective conjugation of x and then conjugating the resulting dot + product. */ + if ( bli_is_conj( conjy ) ) + bli_toggle_conj( &conjx_use ); + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) ) + { + PASTEMAC(c,set0s)( *rho ); + return; + } + + // Initialize local pointers. + x0 = (float*) x; + y0 = (float*) y; + + PASTEMAC(c,set0s)( rho0 ); + + /* + * Computing dot product of 2 complex vectors + * dotProd = Σ (xr + i xi) * (yr - i yi) + * dotProdReal = xr1 * yr1 + xi1 * yi1 + xr2 * yr2 + xi2 * yi2 + .... + xrn * yrn + xin * yin + * dotProdImag = -( xr1 * yi1 - xi1 * yr1 + xr2 * yi2 - xi2 * yr2 + .... 
+ xrn * yin - xin * yrn) + * Product of vectors are carried out using intrinsics code _mm256_fmadd_ps with 256bit register + * Each element of 256bit register is added/subtracted based on element position + */ + if ( incx == 1 && incy == 1 ) + { + /* Set of registers used to compute real value of dot product */ + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + rhov[4].v = _mm256_setzero_ps(); + /* set of registers used to compute imag value of dot product */ + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + rhov[8].v = _mm256_setzero_ps(); + rhov[9].v = _mm256_setzero_ps(); + + /* + * Compute of 1-256bit register + * xv = xr1 xi1 xr2 xi2 xr3 xi3 xr4 xi4 + * yv = yr1 yi1 yr2 yi2 yr3 yi3 yr4 yi4 + * zv = yi1 yr1 yi2 yr2 yi3 yr3 yi4 yr4 + * rhov0(real) = xr1*yr1, xi1*yi1, xr2*yr2, xi2*yi2, xr3*yr3, xi3*yi3, xr4*yr4, xi4*yi4 + * rhov5(imag) = xr1*yi1, xi1*yr1, xr2*yi2, xi2*yr2, xr3*yi3, xi3*yr3, xr4*yi4, xi4*yr4 + */ + for (i=0 ; (i + 19) < n; i += 20 ) + { + // 20 elements will be processed per loop; 10 FMAs will run per loop. + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + /* Permute step is swapping real and imaginary values. + yv = yr1 yi1 yr2 yi2 yr3 yi3 yr4 yi4 + zv = yi1 yr1 yi2 yr2 yi3 yr3 yi4 yr4 + zv is required to compute imaginary values */ + zv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + zv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + zv[2] = _mm256_permute_ps( yv[2], 0xB1 ); + zv[3] = _mm256_permute_ps( yv[3], 0xB1 ); + zv[4] = _mm256_permute_ps( yv[4], 0xB1 ); + + /* Compute real values */ + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); + + /* Compute imaginary values*/ + rhov[5].v = _mm256_fmadd_ps( xv[0], zv[0], rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[1], zv[1], rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[2], zv[2], rhov[7].v ); + rhov[8].v = _mm256_fmadd_ps( xv[3], zv[3], rhov[8].v ); + rhov[9].v = _mm256_fmadd_ps( xv[4], zv[4], rhov[9].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + /* Real value computation: rhov[0] & rhov[1] used in below + for loops hence adding up other register values */ + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + rhov[0].v += rhov[4].v; + + /* Imag value computation: rhov[5] & rhov[6] used in below + for loops hence adding up other register values */ + rhov[5].v += rhov[7].v; + rhov[6].v += rhov[8].v; + rhov[5].v += rhov[9].v; + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + /* Permute step is swapping real and imaginary values. 
+ yv = yr1 yi1 yr2 yi2 yr3 yi3 yr4 yi4 + zv = yi1 yr1 yi2 yr2 yi3 yr3 yi4 yr4 + zv is required to compute imaginary values */ + zv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + zv[1] = _mm256_permute_ps( yv[1], 0xB1 ); + + /* Compute real values */ + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); + + /* Compute imaginary values*/ + rhov[5].v = _mm256_fmadd_ps( xv[0], zv[0], rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[1], zv[1], rhov[6].v ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + /*Accumalte real values in to rhov[0]*/ + rhov[0].v += rhov[1].v; + + /*Accumalte imaginary values in to rhov[5]*/ + rhov[5].v += rhov[6].v; + + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + /* Permute step is swapping real and imaginary values. + yv = yr1 yi1 yr2 yi2 yr3 yi3 yr4 yi4 + zv = yi1 yr1 yi2 yr2 yi3 yr3 yi4 yr4 + zv is required to compute imaginary values */ + zv[0] = _mm256_permute_ps( yv[0], 0xB1 ); + + /* Compute real values */ + rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); + + /* Compute imaginary values*/ + rhov[5].v = _mm256_fmadd_ps( xv[0], zv[0], rhov[5].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + + /* Residual values are calculated here + rho := conjx(x)^T * conjy(y) + n = 1, When no conjugate for x or y vector + rho = conj(xr + xi) * conj(yr + yi) + rho = (xr - xi) * (yr -yi) + rho.real = xr*yr + xi*yi + rho.imag = -(xi*yr - xr *yi) + -ve sign of imaginary value is taken care at the end of function + When vector x/y to be conjugated, imaginary values(xi and yi) to be negated + */ + if ( !bli_is_conj(conjx_use) ) + { + for ( ; (i + 0) < n; i += 1 ) + { + rho0.real += (*x0) * (*y0) - (*(x0+1)) * (*(y0+1)); + rho0.imag += (*x0) * (*(y0+1)) + (*(x0+1)) * (*y0); + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; (i + 0) < n; i += 1 ) + { + rho0.real += (*x0) * (*y0) + (*(x0+1)) * (*(y0+1)); + rho0.imag += (*x0) * (*(y0+1)) - (*(x0+1)) * (*y0); + x0 += 2; + y0 += 2; + } + } + + /* Find dot product by summing up all elements */ + if ( !bli_is_conj(conjx_use) ) + { + rho0.real += rhov[0].f[0] - rhov[0].f[1] + + rhov[0].f[2] - rhov[0].f[3] + + rhov[0].f[4] - rhov[0].f[5] + + rhov[0].f[6] - rhov[0].f[7]; + + rho0.imag += rhov[5].f[0] + rhov[5].f[1] + + rhov[5].f[2] + rhov[5].f[3] + + rhov[5].f[4] + rhov[5].f[5] + + rhov[5].f[6] + rhov[5].f[7]; + } + else + { + rho0.real += rhov[0].f[0] + rhov[0].f[1] + + rhov[0].f[2] + rhov[0].f[3] + + rhov[0].f[4] + rhov[0].f[5] + + rhov[0].f[6] + rhov[0].f[7]; + + rho0.imag += rhov[5].f[0] - rhov[5].f[1] + + rhov[5].f[2] - rhov[5].f[3] + + rhov[5].f[4] - rhov[5].f[5] + + rhov[5].f[6] - rhov[5].f[7]; + } + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjy) ) { + rho0.imag = -rho0.imag; + } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). 
+ _mm256_zeroupper(); + } + else + { + /* rho := conjx(x)^T * conjy(y) + n = 1, When no conjugate for x or y vector + rho = conj(xr + xi) * conj(yr + yi) + rho = (xr - xi) * (yr -yi) + rho.real = xr*yr + xi*yi + rho.imag = -(xi*yr - xr *yi) + -ve sign of imaginary value is taken care at the end of function + When vector x/y to be conjugated, imaginary values(xi and yi) to be negated + */ + if ( !bli_is_conj(conjx_use) ) + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float y0c = *y0; + + const float x1c = *( x0+1 ); + const float y1c = *( y0+1 ); + + rho0.real += x0c * y0c - x1c * y1c; + rho0.imag += x0c * y1c + x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float y0c = *y0; + + const float x1c = *( x0+1 ); + const float y1c = *( y0+1 ); + + rho0.real += x0c * y0c + x1c * y1c; + rho0.imag += x0c * y1c - x1c * y0c; + + x0+= incx * 2; + y0+= incy * 2; + } + } + + /* Negate sign of imaginary value when vector y is conjugate */ + if( bli_is_conj(conjy) ) + rho0.imag = -rho0.imag; + } + + // Copy the final result into the output variable. + PASTEMAC(c,copys)( rho0, *rho ); +} + + +// ----------------------------------------------------------------------------- + +void bli_zdotv_zen_int5 + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 4; + + dim_t i; + + double* restrict x0; + double* restrict y0; + + dcomplex rho0 ; + + rho0.real = 0.0; + rho0.imag = 0.0; + + __m256d xv[5]; + __m256d yv[5]; + __m256d zv[5]; + v4df_t rhov[10]; + + conj_t conjx_use = conjx; + /* If y must be conjugated, we do so indirectly by first toggling the + effective conjugation of x and then conjugating the resulting dot + product. */ + if ( bli_is_conj( conjy ) ) + bli_toggle_conj( &conjx_use ); + + // If the vector dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim1( n ) ) + { + PASTEMAC(z,set0s)( *rho ); + return; + } + + // Initialize local pointers. + x0 = (double *) x; + y0 = (double *) y; + + PASTEMAC(z,set0s)( rho0 ); + + /* + * Computing dot product of 2 complex vectors + * dotProd = Σ (xr + i xi) * (yr - iyi) + * dotProdReal = xr1 * yr1 + xi1 * yi1 + xr2 * yr2 + xi2 * yi2 + .... + xrn * yrn + xin * yin + * dotProdImag = xi1 * yr1 - xr1 * yi1 + xi2 * yr2 - xr2 * yi2 + .... + xin * yrn - xrn * yin + * Product of vectors are carried out using intrinsics code _mm256_fmadd_ps with 256bit register + * Each element of 256bit register is added/subtracted based on element position + */ + if ( incx == 1 && incy == 1 ) + { + /* Set of registers used to compute real value of dot product */ + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + /* Set of registers used to compute real value of dot product */ + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + rhov[8].v = _mm256_setzero_pd(); + rhov[9].v = _mm256_setzero_pd(); + + /* + * Compute of 1-256bit register + * xv = xr1 xi1 xr2 xi2 + * yv = yr1 yi1 yr1 yi2 + * zv = yi1 yr1 yi1 yr2 + * rhov0(real) = xr1*yr1, xi1*yi1, xr2*yr2, xi2*yi2 + * rhov5(imag) = xr1*yi1, xi1*yr1, xr2*yi2, xi2*yr2 + */ + for ( i = 0; (i + 9) < n; i += 10 ) + { + // 10 elements will be processed per loop; 10 FMAs will run per loop. 
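+			// That is, 10 dcomplex elements are read from each of x and y per
+			// iteration (5 ymm registers holding 2 dcomplex each per vector);
+			// rhov[0..4] accumulate the xr*yr / xi*yi products (real part) and
+			// rhov[5..9] accumulate the xr*yi / xi*yr products (imaginary part)
+			// formed with the permuted copies zv of y.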
+ xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + /* Permute step is swapping real and imaginary values. + yv = yr1 yi1 yr2 yi2 + zv = yi1 yr1 yi2 yr2 + zv is required to compute imaginary values */ + zv[0] = _mm256_permute_pd( yv[0], 5 ); + zv[1] = _mm256_permute_pd( yv[1], 5 ); + zv[2] = _mm256_permute_pd( yv[2], 5 ); + zv[3] = _mm256_permute_pd( yv[3], 5 ); + zv[4] = _mm256_permute_pd( yv[4], 5 ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[0], zv[0], rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[1], zv[1], rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[2], zv[2], rhov[7].v ); + rhov[8].v = _mm256_fmadd_pd( xv[3], zv[3], rhov[8].v ); + rhov[9].v = _mm256_fmadd_pd( xv[4], zv[4], rhov[9].v ); + + x0 += 5*n_elem_per_reg; + y0 += 5*n_elem_per_reg; + } + + /* Real value computation: rhov[0] & rhov[1] used in below + for loops hence adding up other register values */ + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + rhov[0].v += rhov[4].v; + + /* Imag value computation: rhov[5] & rhov[6] used in below + for loops hence adding up other register values */ + rhov[5].v += rhov[7].v; + rhov[6].v += rhov[8].v; + rhov[5].v += rhov[9].v; + + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + /* Permute step is swapping real and imaginary values. + yv = yr1 yi1 yr2 yi2 + zv = yi1 yr1 yi2 yr2 + zv is required to compute imaginary values */ + zv[0] = _mm256_permute_pd( yv[0], 5 ); + zv[1] = _mm256_permute_pd( yv[1], 5 ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); + rhov[5].v = _mm256_fmadd_pd( xv[0], zv[0], rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[1], zv[1], rhov[6].v ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + } + + /*Accumulate real values*/ + rhov[0].v += rhov[1].v; + /*Accumulate imaginary values*/ + rhov[5].v += rhov[6].v; + + for ( ; (i + 3) < n; i += 2 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + /* Permute step is swapping real and imaginary values. 
+ yv = yr1 yi1 yr2 yi2 + zv = yi1 yr1 yi2 yr2 + zv is required to compute imaginary values */ + zv[0] = _mm256_permute_pd( yv[0], 5 ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + rhov[5].v = _mm256_fmadd_pd( xv[0], zv[0], rhov[5].v ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + } + + /* Residual values are calculated here + rho := conjx(x)^T * conjy(y) + n = 1, When no conjugate for x or y vector + rho = conj(xr + xi) * conj(yr + yi) + rho = (xr - xi) * (yr -yi) + rho.real = xr*yr + xi*yi + rho.imag = -(xi*yr - xr *yi) + -ve sign of imaginary value is taken care at the end of function + When vector x/y to be conjugated, imaginary values(xi and yi) to be negated + */ + if ( !bli_is_conj(conjx_use) ) + { + for ( ; (i + 0) < n; i += 1 ) + { + rho0.real += ( *x0 ) * ( *y0 ) - ( *(x0+1) ) * ( *( y0+1 ) ); + rho0.imag += ( *x0 ) * ( *(y0+1) ) + ( *(x0+1) ) * ( *y0 ); + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; (i + 0) < n; i += 1 ) + { + rho0.real += ( *x0 ) * ( *y0 ) + ( *( x0+1 ) ) * ( *( y0+1 ) ); + rho0.imag += ( *x0 ) * ( *( y0+1 ) ) - ( *( x0+1 ) ) * ( *y0 ); + x0 += 2; + y0 += 2; + } + } + + /* Find dot product by summing up all elements */ + if ( !bli_is_conj(conjx_use) ) + { + rho0.real += rhov[0].d[0] - rhov[0].d[1] + rhov[0].d[2] - rhov[0].d[3]; + rho0.imag += rhov[5].d[0] + rhov[5].d[1] + rhov[5].d[2] + rhov[5].d[3]; + } + else + { + rho0.real += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3]; + rho0.imag += rhov[5].d[0] - rhov[5].d[1] + rhov[5].d[2] - rhov[5].d[3]; + } + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjy) ) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + } + else + { + /* rho := conjx(x)^T * conjy(y) + n = 1, When no conjugate for x or y vector + rho = conj(xr + xi) * conj(yr + yi) + rho = (xr - xi) * (yr -yi) + rho.real = xr*yr + xi*yi + rho.imag = -(xi*yr - xr *yi) + -ve sign of imaginary value is taken care at the end of function + When vector x/y to be conjugated, imaginary values(xi and yi) to be negated + */ + if ( !bli_is_conj(conjx_use) ) + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + const double x1c = *( x0 + 1 ); + const double y1c = *( y0 + 1 ); + + rho0.real += x0c * y0c - x1c * y1c; + rho0.imag += x0c * y1c + x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double y0c = *y0; + + const double x1c = *( x0 + 1 ); + const double y1c = *( y0 + 1 ); + + rho0.real += x0c * y0c + x1c * y1c; + rho0.imag += x0c * y1c - x1c * y0c; + + x0 += incx * 2; + y0 += incy * 2; + } + } + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjy) ) + rho0.imag = -rho0.imag; + } + + // Copy the final result into the output variable. + PASTEMAC(z,copys)( rho0, *rho ); +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 161bcef1aa..df9e85eb40 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,6 +53,8 @@ AXPYV_KER_PROT( double, d, axpyv_zen_int ) // axpyv (intrinsics unrolled x10) AXPYV_KER_PROT( float, s, axpyv_zen_int10 ) AXPYV_KER_PROT( double, d, axpyv_zen_int10 ) +AXPYV_KER_PROT( scomplex, c, axpyv_zen_int5 ) +AXPYV_KER_PROT( dcomplex, z, axpyv_zen_int5 ) // dotv (intrinsics) DOTV_KER_PROT( float, s, dotv_zen_int ) @@ -61,6 +63,8 @@ DOTV_KER_PROT( double, d, dotv_zen_int ) // dotv (intrinsics, unrolled x10) DOTV_KER_PROT( float, s, dotv_zen_int10 ) DOTV_KER_PROT( double, d, dotv_zen_int10 ) +DOTV_KER_PROT( scomplex, c, dotv_zen_int5 ) +DOTV_KER_PROT( dcomplex, z, dotv_zen_int5 ) // dotxv (intrinsics) DOTXV_KER_PROT( float, s, dotxv_zen_int ) @@ -115,7 +119,7 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x8 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x8 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4 ) -GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x4 ) diff --git a/test/test_axpyv.c b/test/test_axpyv.c index 44a0d2d746..61ddb9c591 100644 --- a/test/test_axpyv.c +++ b/test/test_axpyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,168 +42,194 @@ // n alpha x incx y incy //void daxpyv_( int*, double*, double*, int*, double*, int* ); -//#define PRINT +// #define PRINT int main( int argc, char** argv ) { - obj_t x, y; - obj_t y_save; - obj_t alpha; - dim_t n; - dim_t p; - dim_t p_begin, p_end, p_inc; - int n_input; - num_t dt_x, dt_y; - num_t dt_alpha; - int r, n_repeats; - num_t dt; - - double dtime; - double dtime_save; - double gflops; - - bli_init(); - - n_repeats = 3; + obj_t x, y; + obj_t y_save; + obj_t alpha; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input; + num_t dt_x, dt_y; + num_t dt_alpha; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double gflops; + + bli_init(); + + n_repeats = 1; #ifndef PRINT - p_begin = 40; - p_end = 4000; - p_inc = 40; + p_begin = 10; + p_end = 100; + p_inc = 10; - n_input = -1; + n_input = -1; #else - p_begin = 16; - p_end = 16; - p_inc = 1; + p_begin = 16; + p_end = 16; + p_inc = 1; - n_input = 15; + n_input = 15; #endif #if 1 - dt = BLIS_FLOAT; - //dt = BLIS_DOUBLE; + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; #else - //dt = BLIS_SCOMPLEX; - dt = BLIS_DCOMPLEX; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; #endif + dt_x = dt_y = dt_alpha = dt; - dt_x = dt_y = dt_alpha = dt; - - // Begin with initializing the last entry to zero so that - // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. 
+ for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; #ifdef BLIS - printf( "data_axpyv_blis" ); + printf( "data_axpyv_blis" ); #else - printf( "data_axpyv_%s", BLAS ); + printf( "data_axpyv_%s", BLAS ); #endif - printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )0, 0.0 ); + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, 0.0 ); - //for ( p = p_begin; p <= p_end; p += p_inc ) - for ( p = p_end; p_begin <= p; p -= p_inc ) - { + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) + { - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; - bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha ); - bli_obj_create( dt_x, n, 1, 0, 0, &x ); - bli_obj_create( dt_y, n, 1, 0, 0, &y ); - bli_obj_create( dt_y, n, 1, 0, 0, &y_save ); + bli_obj_create( dt_x, n, 1, 0, 0, &x ); + bli_obj_create( dt_y, n, 1, 0, 0, &y ); + bli_obj_create( dt_y, n, 1, 0, 0, &y_save ); - bli_randm( &x ); - bli_randm( &y ); + bli_randm( &x ); + bli_randm( &y ); - bli_setsc( (2.0/1.0), 0.0, &alpha ); + bli_setsc( (2.0/1.0), 0.0, &alpha ); - bli_copym( &y, &y_save ); + bli_copym( &y, &y_save ); - dtime_save = 1.0e9; + dtime_save = 1.0e9; - for ( r = 0; r < n_repeats; ++r ) - { - bli_copym( &y_save, &y ); + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &y_save, &y ); - dtime = bli_clock(); + dtime = bli_clock(); #ifdef PRINT - bli_printm( "alpha", &alpha, "%4.1f", "" ); - bli_printm( "x", &x, "%4.1f", "" ); - bli_printm( "y", &y, "%4.1f", "" ); + bli_printm( "alpha", &alpha, "%4.1f", "" ); + bli_printm( "x", &x, "%4.1f", "" ); + bli_printm( "y", &y, "%4.1f", "" ); #endif #ifdef BLIS - bli_axpyv( &alpha, - &x, - &y ); + bli_axpyv( &alpha, + &x, + &y ); #else - if ( bli_is_float( dt ) ) - { - f77_int nn = bli_obj_length( &x ); - f77_int incx = bli_obj_vector_inc( &x ); - f77_int incy = bli_obj_vector_inc( &y ); - float* alphap = bli_obj_buffer( &alpha ); - float* xp = bli_obj_buffer( &x ); - float* yp = bli_obj_buffer( &y ); - - saxpy_( &nn, - alphap, - xp, &incx, - yp, &incy ); - - - } - else if ( bli_is_double( dt ) ) - { - - f77_int nn = bli_obj_length( &x ); - f77_int incx = bli_obj_vector_inc( &x ); - f77_int incy = bli_obj_vector_inc( &y ); - double* alphap = bli_obj_buffer( &alpha ); - double* xp = bli_obj_buffer( &x ); - double* yp = bli_obj_buffer( &y ); - - daxpy_( &nn, - alphap, - xp, &incx, - yp, &incy ); - } + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); + + saxpy_( &nn, + alphap, + xp, &incx, + yp, &incy ); + + } + else if ( bli_is_double( dt ) ) + { + + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); + + daxpy_( &nn, + alphap, + xp, &incx, + yp, &incy ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + void* alphap = bli_obj_buffer( &alpha ); + void* xp = bli_obj_buffer( &x ); + 
void* yp = bli_obj_buffer( &y ); + + caxpy_( &nn, + (scomplex*)alphap, + (scomplex*)xp, &incx, + (scomplex*)yp, &incy ); + } + else if ( bli_is_dcomplex( dt )) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + void* alphap = bli_obj_buffer( &alpha ); + void* xp = bli_obj_buffer( &x ); + void* yp = bli_obj_buffer( &y ); + + zaxpy_( &nn, + (dcomplex*)alphap, + (dcomplex*)xp, &incx, + (dcomplex*)yp, &incy ); + } #endif #ifdef PRINT - bli_printm( "y after", &y, "%4.1f", "" ); - exit(1); + bli_printm( "y after", &y, "%4.1f", "" ); + exit(1); #endif + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } - dtime_save = bli_clock_min_diff( dtime_save, dtime ); - } - - gflops = ( 2.0 * n ) / ( dtime_save * 1.0e9 ); + gflops = ( 2.0 * n ) / ( dtime_save * 1.0e9 ); + if ( bli_obj_is_complex( &x ) ) gflops *= 4.0; #ifdef BLIS - printf( "data_axpyv_blis" ); + printf( "data_axpyv_blis" ); #else - printf( "data_axpyv_%s", BLAS ); + printf( "data_axpyv_%s", BLAS ); #endif - printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )n, gflops ); + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )n, gflops ); - bli_obj_free( &alpha ); + bli_obj_free( &alpha ); - bli_obj_free( &x ); - bli_obj_free( &y ); - bli_obj_free( &y_save ); - } + bli_obj_free( &x ); + bli_obj_free( &y ); + bli_obj_free( &y_save ); + } - bli_finalize(); + bli_finalize(); - return 0; + return 0; }
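The driver above multiplies the reported rate by 4.0 for complex data because a complex axpy performs four times the real flops of a real axpy of the same length. The scalar sketch below is only to make that accounting concrete; caxpy_ref and scomplex_ref are illustrative names, not BLIS or BLAS symbols, and the intrinsics kernels registered by this patch (such as the cdotv/zdotv kernels shown earlier) vectorize the same arithmetic with AVX2 FMAs.

/* Scalar reference for y := alpha*x + y with single-precision complex data.
   Per element: 4 real multiplies and 2 adds/subtracts form alpha*x, and 2 more
   adds accumulate into y, i.e. 8 real flops versus 2 for the real case --
   hence the driver's gflops *= 4.0 adjustment over the real 2n-flop count. */
#include <stddef.h>

typedef struct { float real, imag; } scomplex_ref;

static void caxpy_ref( size_t n, scomplex_ref alpha,
                       const scomplex_ref* x, size_t incx,
                       scomplex_ref*       y, size_t incy )
{
	for ( size_t i = 0; i < n; ++i )
	{
		const scomplex_ref xc = x[ i*incx ];

		/* (ar + i*ai) * (xr + i*xi) */
		const float pr = alpha.real * xc.real - alpha.imag * xc.imag;
		const float pi = alpha.real * xc.imag + alpha.imag * xc.real;

		y[ i*incy ].real += pr;
		y[ i*incy ].imag += pi;
	}
}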