From 34c036c1bf37e23a80a844e385167cd8c0a8964d Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 2 Jul 2024 14:16:50 +0530 Subject: [PATCH 1/6] Matmul: dispatch on specific blas paths using an enum --- stdlib/LinearAlgebra/src/matmul.jl | 61 +++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index f64422fd9cb8a..5e9ef5912dc7a 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -301,16 +301,36 @@ true """ @inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β) # Add a level of indirection and specialize _mul! to avoid ambiguities in mul! -@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = +module BLASMatMul +@enum BlasFunction SYRK HERK GEMM NONE +end +@inline function _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) + tA = wrapper_char(A) + tB = wrapper_char(B) + tA_uc = uppercase(tA) + tB_uc = uppercase(tB) + isntc = wrapper_char_NTC(A) & wrapper_char_NTC(B) + blasfn = if (tA_uc == 'T' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'T') + BLASMatMul.SYRK + elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C') + BLASMatMul.HERK + elseif isntc + BLASMatMul.GEMM + else + BLASMatMul.NONE + end + generic_matmatmul_wrapper!( C, - wrapper_char(A), - wrapper_char(B), + tA, + tB, _unwrap(A), _unwrap(B), α, β, - Val(wrapper_char_NTC(A) & wrapper_char_NTC(B)) + Val(isntc), + Val(blasfn) ) +end # this indirection allows is to specialize on the types of the wrappers of A and B to some extent, # even though the wrappers are stripped off in mul! @@ -415,7 +435,7 @@ end # THE one big BLAS dispatch. This is split into two methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{true}) where {T<:BlasFloat} + α::Number, β::Number, ::Val{true}, ::Val{blasfn}) where {T<:BlasFloat, blasfn} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) @@ -425,24 +445,29 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix return _rmul_or_fill!(C, β) end matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C - # We convert the chars to uppercase to potentially unwrap a WrapperChar, - # and extract the char corresponding to the wrapper type - tA_uc, tB_uc = uppercase(tA), uppercase(tB) - # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. - if tA_uc == 'T' && tB_uc == 'N' && A === B - return syrk_wrapper!(C, 'T', A, α, β) - elseif tA_uc == 'N' && tB_uc == 'T' && A === B - return syrk_wrapper!(C, 'N', A, α, β) - elseif tA_uc == 'C' && tB_uc == 'N' && A === B - return herk_wrapper!(C, 'C', A, α, β) - elseif tA_uc == 'N' && tB_uc == 'C' && A === B - return herk_wrapper!(C, 'N', A, α, β) + _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, Val(blasfn)) +end +Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BLASMatMul.SYRK}) + tA_uc = uppercase(tA) + if A === B + return syrk_wrapper!(C, tA_uc, A, α, β) else return gemm_wrapper!(C, tA, tB, A, B, α, β) end end +Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BLASMatMul.HERK}) + tA_uc = uppercase(tA) + if A === B + return herk_wrapper!(C, tA_uc, A, α, β) + else + return gemm_wrapper!(C, tA, tB, A, B, α, β) + end +end +Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BLASMatMul.GEMM}) + return gemm_wrapper!(C, tA, tB, A, B, α, β) +end Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{false}) where {T<:BlasFloat} + α::Number, β::Number, ::Val{false}, @nospecialize(::Val)) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) From eb146f67011f59a11aefa876cc6bc0d8399d1250 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 2 Jul 2024 22:31:58 +0530 Subject: [PATCH 2/6] Symm/Hemm flags --- stdlib/LinearAlgebra/src/matmul.jl | 94 +++++++++++++++++++----------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 5e9ef5912dc7a..243e846cb8b42 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -301,8 +301,10 @@ true """ @inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β) # Add a level of indirection and specialize _mul! to avoid ambiguities in mul! -module BLASMatMul -@enum BlasFunction SYRK HERK GEMM NONE +module BlasFlag +@enum BlasFunction SYRK HERK GEMM SYMM HEMM NONE +const ValSyrkHerkGemm = Union{Val{SYRK}, Val{HERK}, Val{GEMM}} +const ValSymmHemmGeneric = Union{Val{SYMM}, Val{HEMM}, Val{NONE}} end @inline function _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) tA = wrapper_char(A) @@ -310,14 +312,22 @@ end tA_uc = uppercase(tA) tB_uc = uppercase(tB) isntc = wrapper_char_NTC(A) & wrapper_char_NTC(B) - blasfn = if (tA_uc == 'T' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'T') - BLASMatMul.SYRK - elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C') - BLASMatMul.HERK - elseif isntc - BLASMatMul.GEMM + blasfn = if isntc + if (tA_uc == 'T' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'T') + BlasFlag.SYRK + elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C') + BlasFlag.HERK + else isntc + BlasFlag.GEMM + end else - BLASMatMul.NONE + if (tA_uc == 'S' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'S') + BlasFlag.SYMM + elseif (tA_uc == 'H' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'H') + BlasFlag.HEMM + else + BlasFlag.NONE + end end generic_matmatmul_wrapper!( @@ -327,8 +337,7 @@ end _unwrap(A), _unwrap(B), α, β, - Val(isntc), - Val(blasfn) + Val(blasfn), ) end @@ -435,7 +444,7 @@ end # THE one big BLAS dispatch. This is split into two methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{true}, ::Val{blasfn}) where {T<:BlasFloat, blasfn} + α::Number, β::Number, val::BlasFlag.ValSyrkHerkGemm) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) @@ -445,29 +454,31 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix return _rmul_or_fill!(C, β) end matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C - _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, Val(blasfn)) + _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val) + return C end -Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BLASMatMul.SYRK}) - tA_uc = uppercase(tA) +Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK}) if A === B + tA_uc = uppercase(tA) # potentially strip a WrapperChar return syrk_wrapper!(C, tA_uc, A, α, β) else return gemm_wrapper!(C, tA, tB, A, B, α, β) end end -Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BLASMatMul.HERK}) - tA_uc = uppercase(tA) +Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.HERK}) if A === B + tA_uc = uppercase(tA) # potentially strip a WrapperChar return herk_wrapper!(C, tA_uc, A, α, β) else return gemm_wrapper!(C, tA, tB, A, B, α, β) end end -Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BLASMatMul.GEMM}) +Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.GEMM}) return gemm_wrapper!(C, tA, tB, A, B, α, β) end +_valtypeparam(v::Val{T}) where {T} = T Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{false}, @nospecialize(::Val)) where {T<:BlasFloat} + α::Number, β::Number, val::BlasFlag.ValSymmHemmGeneric) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) @@ -477,23 +488,40 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix return _rmul_or_fill!(C, β) end matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C - # We convert the chars to uppercase to potentially unwrap a WrapperChar, - # and extract the char corresponding to the wrapper type - tA_uc, tB_uc = uppercase(tA), uppercase(tB) alpha, beta = promote(α, β, zero(T)) - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - if tA_uc == 'S' && tB_uc == 'N' - return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) - elseif tA_uc == 'N' && tB_uc == 'S' - return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) - elseif tA_uc == 'H' && tB_uc == 'N' - return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) - elseif tA_uc == 'N' && tB_uc == 'H' - return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) - end + blasfn = _valtypeparam(val) + if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && blasfn ∈ (BlasFlag.SYMM, BlasFlag.HEMM) + _blasfn = blasfn + αβ = (alpha, beta) + else + _blasfn = BlasFlag.NONE + αβ = (α, β) + end + _symm_hemm_generic!(C, tA, tB, A, B, αβ..., Val(_blasfn)) + return C +end +function _lrchar_ulchar(tA, tB) + if uppercase(tA) == 'N' + lrchar = 'R' + ulchar = isuppercase(tB) ? 'U' : 'L' + else + lrchar = 'L' + ulchar = isuppercase(tA) ? 'U' : 'L' end - return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) + return lrchar, ulchar end +function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.SYMM}) + lrchar, ulchar = _lrchar_ulchar(tA, tB) + BLAS.symm!(lrchar, ulchar, alpha, A, B, beta, C) +end +function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM}) + lrchar, ulchar = _lrchar_ulchar(tA, tB) + BLAS.hemm!(lrchar, ulchar, alpha, A, B, beta, C) +end +Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.NONE}) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) +end + # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = From bb02c14aae622f334398c20d82d121e28a2ecd93 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 3 Jul 2024 12:58:37 +0530 Subject: [PATCH 3/6] Fix branches in symm/hemm --- stdlib/LinearAlgebra/src/matmul.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 243e846cb8b42..99e77a38185a1 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -512,11 +512,19 @@ function _lrchar_ulchar(tA, tB) end function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.SYMM}) lrchar, ulchar = _lrchar_ulchar(tA, tB) - BLAS.symm!(lrchar, ulchar, alpha, A, B, beta, C) + if lrchar == 'L' + BLAS.symm!(lrchar, ulchar, alpha, A, B, beta, C) + else + BLAS.symm!(lrchar, ulchar, alpha, B, A, beta, C) + end end function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM}) lrchar, ulchar = _lrchar_ulchar(tA, tB) - BLAS.hemm!(lrchar, ulchar, alpha, A, B, beta, C) + if lrchar == 'L' + BLAS.hemm!(lrchar, ulchar, alpha, A, B, beta, C) + else + BLAS.hemm!(lrchar, ulchar, alpha, B, A, beta, C) + end end Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.NONE}) _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) From cbdea0a8617a13a1b904d9bc9a6f7e06410c0a93 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 23 Oct 2024 16:23:35 +0530 Subject: [PATCH 4/6] Add `@stable_muladdmul` to `_symm_hemm_generic!` for `NONE` --- stdlib/LinearAlgebra/src/matmul.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 99e77a38185a1..d45f3d63ad493 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -500,7 +500,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix _symm_hemm_generic!(C, tA, tB, A, B, αβ..., Val(_blasfn)) return C end -function _lrchar_ulchar(tA, tB) +Base.@constprop :aggressive function _lrchar_ulchar(tA, tB) if uppercase(tA) == 'N' lrchar = 'R' ulchar = isuppercase(tB) ? 'U' : 'L' @@ -526,8 +526,8 @@ function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM}) BLAS.hemm!(lrchar, ulchar, alpha, B, A, beta, C) end end -Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.NONE}) - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) +Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.NONE}) + @stable_muladdmul _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(alpha, beta)) end # legacy method From 05452f4ed8effc895072cc1d5144ffda9ac84664 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 23 Oct 2024 17:38:45 +0530 Subject: [PATCH 5/6] alpha,beta instead of MulAddMul in _generic_matmatmul! --- stdlib/LinearAlgebra/src/matmul.jl | 45 ++++++++++++++++-------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index d45f3d63ad493..19c16f8957871 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -527,7 +527,7 @@ function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM}) end end Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.NONE}) - @stable_muladdmul _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(alpha, beta)) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) end # legacy method @@ -540,8 +540,8 @@ function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::S gemm_wrapper!(C, tA, tB, A, B, α, β) end Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{false}) where {T<:BlasReal} - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) + alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal} + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) end # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, @@ -743,7 +743,7 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) gemm_wrapper!(C, tA, tB, A, B, true, false) else - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul()) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), true, false) end end @@ -770,7 +770,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab _fullstride2(A) && _fullstride2(B) && _fullstride2(C)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) end # legacy method gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, @@ -805,7 +805,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T} BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) end # legacy method gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, @@ -975,12 +975,16 @@ end # aggressive const prop makes mixed eltype mul!(C, A, B) invoke _generic_matmatmul! directly # legacy method Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul()) = - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) -Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta) +Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha::Number, beta::Number) = + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) -@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, - _add::MulAddMul{ais1}) where {T,S,R,ais1} +# legacy method +_generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) = + _generic_matmatmul!(C, A, B, _add.alpha, _add.beta) + +@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat, B::AbstractVecOrMat, + alpha::Number, beta::Number) where {R} AxM = axes(A, 1) AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector` BxK = axes(B, 1) @@ -996,34 +1000,33 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end - _rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false) if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) - _rmul_or_fill!(C, _add.beta) - (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C + _rmul_or_fill!(C, beta) + (iszero(alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in BxN, k in BxK # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) - Balpha = _rmul_alpha(B[k,n]) + Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n]) @simd for m in AxM C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end end elseif isbitstype(R) && sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose)) - _rmul_or_fill!(C, _add.beta) - (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C + _rmul_or_fill!(C, beta) + (iszero(alpha) || isempty(A) || isempty(B)) && return C t = wrapperop(A) pB = parent(B) pA = parent(A) tmp = similar(C, CxN) ci = first(CxM) - ta = t(_add.alpha) + ta = t(alpha) for i in AxM mul!(tmp, pB, view(pA, :, i)) @views C[ci,:] .+= t.(ta .* tmp) ci += 1 end else - if iszero(_add.alpha) || isempty(A) || isempty(B) - return _rmul_or_fill!(C, _add.beta) + if iszero(alpha) || isempty(A) || isempty(B) + return _rmul_or_fill!(C, beta) end a1 = first(AxK) b1 = first(BxK) @@ -1033,7 +1036,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A @simd for k in AxK Ctmp = muladd(A[i, k], B[k, j], Ctmp) end - _modify!(_add, Ctmp, C, (i,j)) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), Ctmp, C, (i,j)) end end return C From 1dc825e830afd981950e59bda933a87db8a23cf9 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 24 Oct 2024 15:40:50 +0530 Subject: [PATCH 6/6] Remove Val from BlasFlag union variable --- stdlib/LinearAlgebra/src/matmul.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 19c16f8957871..32378212d7346 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -303,8 +303,8 @@ true # Add a level of indirection and specialize _mul! to avoid ambiguities in mul! module BlasFlag @enum BlasFunction SYRK HERK GEMM SYMM HEMM NONE -const ValSyrkHerkGemm = Union{Val{SYRK}, Val{HERK}, Val{GEMM}} -const ValSymmHemmGeneric = Union{Val{SYMM}, Val{HEMM}, Val{NONE}} +const SyrkHerkGemm = Union{Val{SYRK}, Val{HERK}, Val{GEMM}} +const SymmHemmGeneric = Union{Val{SYMM}, Val{HEMM}, Val{NONE}} end @inline function _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) tA = wrapper_char(A) @@ -444,7 +444,7 @@ end # THE one big BLAS dispatch. This is split into two methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, val::BlasFlag.ValSyrkHerkGemm) where {T<:BlasFloat} + α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) @@ -478,7 +478,7 @@ Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, end _valtypeparam(v::Val{T}) where {T} = T Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, val::BlasFlag.ValSymmHemmGeneric) where {T<:BlasFloat} + α::Number, β::Number, val::BlasFlag.SymmHemmGeneric) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)