@@ -294,16 +294,45 @@ true
294294"""
295295@inline mul! (C:: AbstractMatrix , A:: AbstractVecOrMat , B:: AbstractVecOrMat , α:: Number , β:: Number ) = _mul! (C, A, B, α, β)
296296# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
297- @inline _mul! (C:: AbstractMatrix , A:: AbstractVecOrMat , B:: AbstractVecOrMat , α:: Number , β:: Number ) =
297+ module BlasFlag
298+ @enum BlasFunction SYRK HERK GEMM SYMM HEMM NONE
299+ const SyrkHerkGemm = Union{Val{SYRK}, Val{HERK}, Val{GEMM}}
300+ const SymmHemmGeneric = Union{Val{SYMM}, Val{HEMM}, Val{NONE}}
301+ end
302+ @inline function _mul! (C:: AbstractMatrix , A:: AbstractVecOrMat , B:: AbstractVecOrMat , α:: Number , β:: Number )
303+ tA = wrapper_char (A)
304+ tB = wrapper_char (B)
305+ tA_uc = uppercase (tA)
306+ tB_uc = uppercase (tB)
307+ isntc = wrapper_char_NTC (A) & wrapper_char_NTC (B)
308+ blasfn = if isntc
309+ if (tA_uc == ' T' && tB_uc == ' N' ) || (tA_uc == ' N' && tB_uc == ' T' )
310+ BlasFlag. SYRK
311+ elseif (tA_uc == ' C' && tB_uc == ' N' ) || (tA_uc == ' N' && tB_uc == ' C' )
312+ BlasFlag. HERK
313+ else isntc
314+ BlasFlag. GEMM
315+ end
316+ else
317+ if (tA_uc == ' S' && tB_uc == ' N' ) || (tA_uc == ' N' && tB_uc == ' S' )
318+ BlasFlag. SYMM
319+ elseif (tA_uc == ' H' && tB_uc == ' N' ) || (tA_uc == ' N' && tB_uc == ' H' )
320+ BlasFlag. HEMM
321+ else
322+ BlasFlag. NONE
323+ end
324+ end
325+
298326 generic_matmatmul_wrapper! (
299327 C,
300- wrapper_char (A) ,
301- wrapper_char (B) ,
328+ tA ,
329+ tB ,
302330 _unwrap (A),
303331 _unwrap (B),
304332 α, β,
305- Val (wrapper_char_NTC (A) & wrapper_char_NTC (B))
333+ Val (blasfn),
306334 )
335+ end
307336
308337# this indirection allows is to specialize on the types of the wrappers of A and B to some extent,
309338# even though the wrappers are stripped off in mul!
408437
409438# THE one big BLAS dispatch. This is split into two methods to improve latency
410439Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
411- α:: Number , β:: Number , :: Val{true} ) where {T<: BlasFloat }
440+ α:: Number , β:: Number , val :: BlasFlag.SyrkHerkGemm ) where {T<: BlasFloat }
412441 mA, nA = lapack_size (tA, A)
413442 mB, nB = lapack_size (tB, B)
414443 if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
@@ -418,24 +447,31 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
418447 return _rmul_or_fill! (C, β)
419448 end
420449 matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
421- # We convert the chars to uppercase to potentially unwrap a WrapperChar,
422- # and extract the char corresponding to the wrapper type
423- tA_uc, tB_uc = uppercase (tA), uppercase (tB)
424- # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
425- if tA_uc == ' T' && tB_uc == ' N' && A === B
426- return syrk_wrapper! (C, ' T' , A, α, β)
427- elseif tA_uc == ' N' && tB_uc == ' T' && A === B
428- return syrk_wrapper! (C, ' N' , A, α, β)
429- elseif tA_uc == ' C' && tB_uc == ' N' && A === B
430- return herk_wrapper! (C, ' C' , A, α, β)
431- elseif tA_uc == ' N' && tB_uc == ' C' && A === B
432- return herk_wrapper! (C, ' N' , A, α, β)
450+ _syrk_herk_gemm_wrapper! (C, tA, tB, A, B, α, β, val)
451+ return C
452+ end
453+ Base. @constprop :aggressive function _syrk_herk_gemm_wrapper! (C, tA, tB, A, B, α, β, :: Val{BlasFlag.SYRK} )
454+ if A === B
455+ tA_uc = uppercase (tA) # potentially strip a WrapperChar
456+ return syrk_wrapper! (C, tA_uc, A, α, β)
433457 else
434458 return gemm_wrapper! (C, tA, tB, A, B, α, β)
435459 end
436460end
461+ Base. @constprop :aggressive function _syrk_herk_gemm_wrapper! (C, tA, tB, A, B, α, β, :: Val{BlasFlag.HERK} )
462+ if A === B
463+ tA_uc = uppercase (tA) # potentially strip a WrapperChar
464+ return herk_wrapper! (C, tA_uc, A, α, β)
465+ else
466+ return gemm_wrapper! (C, tA, tB, A, B, α, β)
467+ end
468+ end
469+ Base. @constprop :aggressive function _syrk_herk_gemm_wrapper! (C, tA, tB, A, B, α, β, :: Val{BlasFlag.GEMM} )
470+ return gemm_wrapper! (C, tA, tB, A, B, α, β)
471+ end
472+ _valtypeparam (v:: Val{T} ) where {T} = T
437473Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
438- α:: Number , β:: Number , :: Val{false} ) where {T<: BlasFloat }
474+ α:: Number , β:: Number , val :: BlasFlag.SymmHemmGeneric ) where {T<: BlasFloat }
439475 mA, nA = lapack_size (tA, A)
440476 mB, nB = lapack_size (tB, B)
441477 if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
@@ -445,23 +481,48 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
445481 return _rmul_or_fill! (C, β)
446482 end
447483 matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
448- # We convert the chars to uppercase to potentially unwrap a WrapperChar,
449- # and extract the char corresponding to the wrapper type
450- tA_uc, tB_uc = uppercase (tA), uppercase (tB)
451484 alpha, beta = promote (α, β, zero (T))
452- if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
453- if tA_uc == ' S' && tB_uc == ' N'
454- return BLAS. symm! (' L' , tA == ' S' ? ' U' : ' L' , alpha, A, B, beta, C)
455- elseif tA_uc == ' N' && tB_uc == ' S'
456- return BLAS. symm! (' R' , tB == ' S' ? ' U' : ' L' , alpha, B, A, beta, C)
457- elseif tA_uc == ' H' && tB_uc == ' N'
458- return BLAS. hemm! (' L' , tA == ' H' ? ' U' : ' L' , alpha, A, B, beta, C)
459- elseif tA_uc == ' N' && tB_uc == ' H'
460- return BLAS. hemm! (' R' , tB == ' H' ? ' U' : ' L' , alpha, B, A, beta, C)
461- end
485+ blasfn = _valtypeparam (val)
486+ if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && blasfn ∈ (BlasFlag. SYMM, BlasFlag. HEMM)
487+ _blasfn = blasfn
488+ αβ = (alpha, beta)
489+ else
490+ _blasfn = BlasFlag. NONE
491+ αβ = (α, β)
462492 end
463- return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (α, β))
493+ _symm_hemm_generic! (C, tA, tB, A, B, αβ... , Val (_blasfn))
494+ return C
464495end
496+ Base. @constprop :aggressive function _lrchar_ulchar (tA, tB)
497+ if uppercase (tA) == ' N'
498+ lrchar = ' R'
499+ ulchar = isuppercase (tB) ? ' U' : ' L'
500+ else
501+ lrchar = ' L'
502+ ulchar = isuppercase (tA) ? ' U' : ' L'
503+ end
504+ return lrchar, ulchar
505+ end
506+ function _symm_hemm_generic! (C, tA, tB, A, B, alpha, beta, :: Val{BlasFlag.SYMM} )
507+ lrchar, ulchar = _lrchar_ulchar (tA, tB)
508+ if lrchar == ' L'
509+ BLAS. symm! (lrchar, ulchar, alpha, A, B, beta, C)
510+ else
511+ BLAS. symm! (lrchar, ulchar, alpha, B, A, beta, C)
512+ end
513+ end
514+ function _symm_hemm_generic! (C, tA, tB, A, B, alpha, beta, :: Val{BlasFlag.HEMM} )
515+ lrchar, ulchar = _lrchar_ulchar (tA, tB)
516+ if lrchar == ' L'
517+ BLAS. hemm! (lrchar, ulchar, alpha, A, B, beta, C)
518+ else
519+ BLAS. hemm! (lrchar, ulchar, alpha, B, A, beta, C)
520+ end
521+ end
522+ Base. @constprop :aggressive function _symm_hemm_generic! (C, tA, tB, A, B, alpha, beta, :: Val{BlasFlag.NONE} )
523+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
524+ end
525+
465526# legacy method
466527Base. @constprop :aggressive generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
467528 _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
@@ -472,8 +533,8 @@ function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::S
472533 gemm_wrapper! (C, tA, tB, A, B, α, β)
473534end
474535Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
475- α :: Number , β :: Number , :: Val{false} ) where {T<: BlasReal }
476- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (α, β) )
536+ alpha :: Number , beta :: Number , :: Val{false} ) where {T<: BlasReal }
537+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta )
477538end
478539# legacy method
479540Base. @constprop :aggressive generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
@@ -675,7 +736,7 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract
675736 if all (map (in ((' N' , ' T' , ' C' )), (tA_uc, tB_uc)))
676737 gemm_wrapper! (C, tA, tB, A, B, true , false )
677738 else
678- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul () )
739+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), true , false )
679740 end
680741end
681742
@@ -702,7 +763,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
702763 _fullstride2 (A) && _fullstride2 (B) && _fullstride2 (C))
703764 return BLAS. gemm! (tA, tB, alpha, A, B, beta, C)
704765 end
705- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul ( α, β) )
766+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
706767end
707768# legacy method
708769gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
@@ -737,7 +798,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}
737798 BLAS. gemm! (tA, tB, alpha, reinterpret (T, A), B, beta, reinterpret (T, C))
738799 return C
739800 end
740- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul ( α, β) )
801+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
741802end
742803# legacy method
743804gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
@@ -908,12 +969,16 @@ end
908969# aggressive const prop makes mixed eltype mul!(C, A, B) invoke _generic_matmatmul! directly
909970# legacy method
910971Base. @constprop :aggressive generic_matmatmul! (C:: AbstractVecOrMat , tA, tB, A:: AbstractVecOrMat , B:: AbstractVecOrMat , _add:: MulAddMul = MulAddMul ()) =
911- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
912- Base. @constprop :aggressive generic_matmatmul! (C:: AbstractVecOrMat , tA, tB, A:: AbstractVecOrMat , B:: AbstractVecOrMat , α:: Number , β:: Number ) =
913- _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (α, β))
972+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add. alpha, _add. beta)
973+ Base. @constprop :aggressive generic_matmatmul! (C:: AbstractVecOrMat , tA, tB, A:: AbstractVecOrMat , B:: AbstractVecOrMat , alpha:: Number , beta:: Number ) =
974+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
975+
976+ # legacy method
977+ _generic_matmatmul! (C:: AbstractVecOrMat , A:: AbstractVecOrMat , B:: AbstractVecOrMat , _add:: MulAddMul ) =
978+ _generic_matmatmul! (C, A, B, _add. alpha, _add. beta)
914979
915- @noinline function _generic_matmatmul! (C:: AbstractVecOrMat{R} , A:: AbstractVecOrMat{T} , B:: AbstractVecOrMat{S} ,
916- _add :: MulAddMul{ais1} ) where {T,S,R,ais1 }
980+ @noinline function _generic_matmatmul! (C:: AbstractVecOrMat{R} , A:: AbstractVecOrMat , B:: AbstractVecOrMat ,
981+ alpha :: Number , beta :: Number ) where {R }
917982 AxM = axes (A, 1 )
918983 AxK = axes (A, 2 ) # we use two `axes` calls in case of `AbstractVector`
919984 BxK = axes (B, 1 )
@@ -929,34 +994,33 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
929994 if BxN != CxN
930995 throw (DimensionMismatch (lazy " matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)" ))
931996 end
932- _rmul_alpha = MulAddMul {ais1,true,typeof(_add.alpha),Bool} (_add. alpha,false )
933997 if isbitstype (R) && sizeof (R) ≤ 16 && ! (A isa Adjoint || A isa Transpose)
934- _rmul_or_fill! (C, _add . beta)
935- (iszero (_add . alpha) || isempty (A) || isempty (B)) && return C
998+ _rmul_or_fill! (C, beta)
999+ (iszero (alpha) || isempty (A) || isempty (B)) && return C
9361000 @inbounds for n in BxN, k in BxK
9371001 # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
938- Balpha = _rmul_alpha (B[k,n])
1002+ Balpha = @stable_muladdmul MulAddMul (alpha, false ) (B[k,n])
9391003 @simd for m in AxM
9401004 C[m,n] = muladd (A[m,k], Balpha, C[m,n])
9411005 end
9421006 end
9431007 elseif isbitstype (R) && sizeof (R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
944- _rmul_or_fill! (C, _add . beta)
945- (iszero (_add . alpha) || isempty (A) || isempty (B)) && return C
1008+ _rmul_or_fill! (C, beta)
1009+ (iszero (alpha) || isempty (A) || isempty (B)) && return C
9461010 t = wrapperop (A)
9471011 pB = parent (B)
9481012 pA = parent (A)
9491013 tmp = similar (C, CxN)
9501014 ci = first (CxM)
951- ta = t (_add . alpha)
1015+ ta = t (alpha)
9521016 for i in AxM
9531017 mul! (tmp, pB, view (pA, :, i))
9541018 @views C[ci,:] .+ = t .(ta .* tmp)
9551019 ci += 1
9561020 end
9571021 else
958- if iszero (_add . alpha) || isempty (A) || isempty (B)
959- return _rmul_or_fill! (C, _add . beta)
1022+ if iszero (alpha) || isempty (A) || isempty (B)
1023+ return _rmul_or_fill! (C, beta)
9601024 end
9611025 a1 = first (AxK)
9621026 b1 = first (BxK)
@@ -966,7 +1030,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
9661030 @simd for k in AxK
9671031 Ctmp = muladd (A[i, k], B[k, j], Ctmp)
9681032 end
969- _modify! (_add , Ctmp, C, (i,j))
1033+ @stable_muladdmul _modify! (MulAddMul (alpha,beta) , Ctmp, C, (i,j))
9701034 end
9711035 end
9721036 return C
0 commit comments