Skip to content

Commit bf6da77

Browse files
authored
Matmul: dispatch on specific blas paths using an enum (#55002)
This expands on the approach taken by #54552.

We pass on more type information to `generic_matmatmul_wrapper!`, which lets us convert the branches to method dispatches. This helps spread the latency around, so that instead of compiling all the branches in the first call, we now compile the branches only when they are actually taken. While this reduces the latency in individual branches, there is no reduction in latency if all the branches are reachable.

```julia
julia> A = rand(2,2);

julia> @time A * A;
  0.479805 seconds (809.66 k allocations: 40.764 MiB, 99.93% compilation time) # v1.12.0-DEV.806
  0.346739 seconds (633.17 k allocations: 31.320 MiB, 99.90% compilation time) # This PR

julia> @time A * A';
  0.030413 seconds (101.98 k allocations: 5.359 MiB, 98.54% compilation time) # v1.12.0-DEV.806
  0.148118 seconds (219.51 k allocations: 11.652 MiB, 99.72% compilation time) # This PR
```

The latency is spread between the two calls here. In fresh sessions:

```julia
julia> A = rand(2,2);

julia> @time A * A';
  0.473630 seconds (825.65 k allocations: 41.554 MiB, 99.91% compilation time) # v1.12.0-DEV.806
  0.490305 seconds (774.87 k allocations: 38.824 MiB, 99.90% compilation time) # This PR
```

In this case, both the `syrk` and `gemm` branches are reachable, so there is no reduction in latency.

Analogously, there is a reduction in latency in the second set of matrix multiplications where we call `symm!/hemm!` or `_generic_matmatmul`:

```julia
julia> using LinearAlgebra

julia> A = rand(2,2);

julia> @time Symmetric(A) * A;
  0.711178 seconds (2.06 M allocations: 103.878 MiB, 2.20% gc time, 99.98% compilation time) # v1.12.0-DEV.806
  0.540669 seconds (904.12 k allocations: 43.576 MiB, 2.60% gc time, 97.36% compilation time) # This PR
```
1 parent 2a06376 commit bf6da77

File tree

1 file changed

+116
-52
lines changed

1 file changed

+116
-52
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 116 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -294,16 +294,45 @@ true
294294
"""
295295
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β)
296296
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
297-
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
297+
module BlasFlag
298+
@enum BlasFunction SYRK HERK GEMM SYMM HEMM NONE
299+
const SyrkHerkGemm = Union{Val{SYRK}, Val{HERK}, Val{GEMM}}
300+
const SymmHemmGeneric = Union{Val{SYMM}, Val{HEMM}, Val{NONE}}
301+
end
302+
@inline function _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number)
303+
tA = wrapper_char(A)
304+
tB = wrapper_char(B)
305+
tA_uc = uppercase(tA)
306+
tB_uc = uppercase(tB)
307+
isntc = wrapper_char_NTC(A) & wrapper_char_NTC(B)
308+
blasfn = if isntc
309+
if (tA_uc == 'T' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'T')
310+
BlasFlag.SYRK
311+
elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C')
312+
BlasFlag.HERK
313+
else isntc
314+
BlasFlag.GEMM
315+
end
316+
else
317+
if (tA_uc == 'S' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'S')
318+
BlasFlag.SYMM
319+
elseif (tA_uc == 'H' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'H')
320+
BlasFlag.HEMM
321+
else
322+
BlasFlag.NONE
323+
end
324+
end
325+
298326
generic_matmatmul_wrapper!(
299327
C,
300-
wrapper_char(A),
301-
wrapper_char(B),
328+
tA,
329+
tB,
302330
_unwrap(A),
303331
_unwrap(B),
304332
α, β,
305-
Val(wrapper_char_NTC(A) & wrapper_char_NTC(B))
333+
Val(blasfn),
306334
)
335+
end
307336

308337
# this indirection allows us to specialize on the types of the wrappers of A and B to some extent,
309338
# even though the wrappers are stripped off in mul!
@@ -408,7 +437,7 @@ end
408437

409438
# THE one big BLAS dispatch. This is split into two methods to improve latency
410439
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
411-
α::Number, β::Number, ::Val{true}) where {T<:BlasFloat}
440+
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat}
412441
mA, nA = lapack_size(tA, A)
413442
mB, nB = lapack_size(tB, B)
414443
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
@@ -418,24 +447,31 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
418447
return _rmul_or_fill!(C, β)
419448
end
420449
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
421-
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
422-
# and extract the char corresponding to the wrapper type
423-
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
424-
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
425-
if tA_uc == 'T' && tB_uc == 'N' && A === B
426-
return syrk_wrapper!(C, 'T', A, α, β)
427-
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
428-
return syrk_wrapper!(C, 'N', A, α, β)
429-
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
430-
return herk_wrapper!(C, 'C', A, α, β)
431-
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
432-
return herk_wrapper!(C, 'N', A, α, β)
450+
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
451+
return C
452+
end
453+
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK})
454+
if A === B
455+
tA_uc = uppercase(tA) # potentially strip a WrapperChar
456+
return syrk_wrapper!(C, tA_uc, A, α, β)
433457
else
434458
return gemm_wrapper!(C, tA, tB, A, B, α, β)
435459
end
436460
end
461+
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.HERK})
462+
if A === B
463+
tA_uc = uppercase(tA) # potentially strip a WrapperChar
464+
return herk_wrapper!(C, tA_uc, A, α, β)
465+
else
466+
return gemm_wrapper!(C, tA, tB, A, B, α, β)
467+
end
468+
end
469+
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.GEMM})
470+
return gemm_wrapper!(C, tA, tB, A, B, α, β)
471+
end
472+
_valtypeparam(v::Val{T}) where {T} = T
437473
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
438-
α::Number, β::Number, ::Val{false}) where {T<:BlasFloat}
474+
α::Number, β::Number, val::BlasFlag.SymmHemmGeneric) where {T<:BlasFloat}
439475
mA, nA = lapack_size(tA, A)
440476
mB, nB = lapack_size(tB, B)
441477
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
@@ -445,23 +481,48 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
445481
return _rmul_or_fill!(C, β)
446482
end
447483
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
448-
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
449-
# and extract the char corresponding to the wrapper type
450-
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
451484
alpha, beta = promote(α, β, zero(T))
452-
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
453-
if tA_uc == 'S' && tB_uc == 'N'
454-
return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
455-
elseif tA_uc == 'N' && tB_uc == 'S'
456-
return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
457-
elseif tA_uc == 'H' && tB_uc == 'N'
458-
return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
459-
elseif tA_uc == 'N' && tB_uc == 'H'
460-
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
461-
end
485+
blasfn = _valtypeparam(val)
486+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && blasfn ∈ (BlasFlag.SYMM, BlasFlag.HEMM)
487+
_blasfn = blasfn
488+
αβ = (alpha, beta)
489+
else
490+
_blasfn = BlasFlag.NONE
491+
αβ = (α, β)
462492
end
463-
return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
493+
_symm_hemm_generic!(C, tA, tB, A, B, αβ..., Val(_blasfn))
494+
return C
464495
end
496+
Base.@constprop :aggressive function _lrchar_ulchar(tA, tB)
497+
if uppercase(tA) == 'N'
498+
lrchar = 'R'
499+
ulchar = isuppercase(tB) ? 'U' : 'L'
500+
else
501+
lrchar = 'L'
502+
ulchar = isuppercase(tA) ? 'U' : 'L'
503+
end
504+
return lrchar, ulchar
505+
end
506+
function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.SYMM})
507+
lrchar, ulchar = _lrchar_ulchar(tA, tB)
508+
if lrchar == 'L'
509+
BLAS.symm!(lrchar, ulchar, alpha, A, B, beta, C)
510+
else
511+
BLAS.symm!(lrchar, ulchar, alpha, B, A, beta, C)
512+
end
513+
end
514+
function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM})
515+
lrchar, ulchar = _lrchar_ulchar(tA, tB)
516+
if lrchar == 'L'
517+
BLAS.hemm!(lrchar, ulchar, alpha, A, B, beta, C)
518+
else
519+
BLAS.hemm!(lrchar, ulchar, alpha, B, A, beta, C)
520+
end
521+
end
522+
Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.NONE})
523+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
524+
end
525+
465526
# legacy method
466527
Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
467528
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
@@ -472,8 +533,8 @@ function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::S
472533
gemm_wrapper!(C, tA, tB, A, B, α, β)
473534
end
474535
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
475-
α::Number, β::Number, ::Val{false}) where {T<:BlasReal}
476-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
536+
alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal}
537+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
477538
end
478539
# legacy method
479540
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
@@ -675,7 +736,7 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract
675736
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
676737
gemm_wrapper!(C, tA, tB, A, B, true, false)
677738
else
678-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul())
739+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), true, false)
679740
end
680741
end
681742

@@ -702,7 +763,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
702763
_fullstride2(A) && _fullstride2(B) && _fullstride2(C))
703764
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
704765
end
705-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
766+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
706767
end
707768
# legacy method
708769
gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
@@ -737,7 +798,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}
737798
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
738799
return C
739800
end
740-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
801+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
741802
end
742803
# legacy method
743804
gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
@@ -908,12 +969,16 @@ end
908969
# aggressive const prop makes mixed eltype mul!(C, A, B) invoke _generic_matmatmul! directly
909970
# legacy method
910971
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul()) =
911-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
912-
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
913-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
972+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
973+
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha::Number, beta::Number) =
974+
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
975+
976+
# legacy method
977+
_generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) =
978+
_generic_matmatmul!(C, A, B, _add.alpha, _add.beta)
914979

915-
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
916-
_add::MulAddMul{ais1}) where {T,S,R,ais1}
980+
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat, B::AbstractVecOrMat,
981+
alpha::Number, beta::Number) where {R}
917982
AxM = axes(A, 1)
918983
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
919984
BxK = axes(B, 1)
@@ -929,34 +994,33 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
929994
if BxN != CxN
930995
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
931996
end
932-
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
933997
if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose)
934-
_rmul_or_fill!(C, _add.beta)
935-
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
998+
_rmul_or_fill!(C, beta)
999+
(iszero(alpha) || isempty(A) || isempty(B)) && return C
9361000
@inbounds for n in BxN, k in BxK
9371001
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
938-
Balpha = _rmul_alpha(B[k,n])
1002+
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
9391003
@simd for m in AxM
9401004
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
9411005
end
9421006
end
9431007
elseif isbitstype(R) && sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
944-
_rmul_or_fill!(C, _add.beta)
945-
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
1008+
_rmul_or_fill!(C, beta)
1009+
(iszero(alpha) || isempty(A) || isempty(B)) && return C
9461010
t = wrapperop(A)
9471011
pB = parent(B)
9481012
pA = parent(A)
9491013
tmp = similar(C, CxN)
9501014
ci = first(CxM)
951-
ta = t(_add.alpha)
1015+
ta = t(alpha)
9521016
for i in AxM
9531017
mul!(tmp, pB, view(pA, :, i))
9541018
@views C[ci,:] .+= t.(ta .* tmp)
9551019
ci += 1
9561020
end
9571021
else
958-
if iszero(_add.alpha) || isempty(A) || isempty(B)
959-
return _rmul_or_fill!(C, _add.beta)
1022+
if iszero(alpha) || isempty(A) || isempty(B)
1023+
return _rmul_or_fill!(C, beta)
9601024
end
9611025
a1 = first(AxK)
9621026
b1 = first(BxK)
@@ -966,7 +1030,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
9661030
@simd for k in AxK
9671031
Ctmp = muladd(A[i, k], B[k, j], Ctmp)
9681032
end
969-
_modify!(_add, Ctmp, C, (i,j))
1033+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), Ctmp, C, (i,j))
9701034
end
9711035
end
9721036
return C

0 commit comments

Comments
 (0)