Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Construct MulAddMul at gemm_wrapper! call sites #34601

Merged
merged 3 commits into from
Feb 1, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,23 @@ struct MulAddMul{ais1, bis0, TA, TB}
beta::TB
end

MulAddMul(alpha::TA, beta::TB) where {TA, TB} =
MulAddMul{isone(alpha), iszero(beta), TA, TB}(alpha, beta)
@inline function MulAddMul(alpha::TA, beta::TB) where {TA,TB}
if isone(alpha)
if iszero(beta)
return MulAddMul{true,true,TA,TB}(alpha, beta)
else
return MulAddMul{true,false,TA,TB}(alpha, beta)
end
else
if iszero(beta)
return MulAddMul{false,true,TA,TB}(alpha, beta)
else
return MulAddMul{false,false,TA,TB}(alpha, beta)
end
end
end

MulAddMul() = MulAddMul(true, false)
MulAddMul() = MulAddMul{true,true}(true, false)
tkf marked this conversation as resolved.
Show resolved Hide resolved

@inline (::MulAddMul{true})(x) = x
@inline (p::MulAddMul{false})(x) = x * p.alpha
Expand Down
34 changes: 17 additions & 17 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ end

@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasFloat}
return gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta)
return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
end
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
# first matrix as a real matrix and carry out real matrix matrix multiply
Expand Down Expand Up @@ -307,7 +307,7 @@ lmul!(A, B)
if A===B
return syrk_wrapper!(C, 'T', A, alpha, beta)
else
return gemm_wrapper!(C, 'T', 'N', A, B, alpha, beta)
return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
Expand All @@ -322,7 +322,7 @@ end
if A===B
return syrk_wrapper!(C, 'N', A, alpha, beta)
else
return gemm_wrapper!(C, 'N', 'T', A, B, alpha, beta)
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
end
end
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
Expand All @@ -349,7 +349,7 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat}
A = transA.parent
B = transB.parent
return gemm_wrapper!(C, 'T', 'T', A, B, alpha, beta)
return gemm_wrapper!(C, 'T', 'T', A, B, MulAddMul(alpha, beta))
end
@inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number)
Expand All @@ -362,7 +362,7 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat}
A = transA.parent
B = transB.parent
return gemm_wrapper!(C, 'T', 'C', A, B, alpha, beta)
return gemm_wrapper!(C, 'T', 'C', A, B, MulAddMul(alpha, beta))
end
@inline function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number)
Expand All @@ -382,7 +382,7 @@ end
if A===B
return herk_wrapper!(C, 'C', A, alpha, beta)
else
return gemm_wrapper!(C, 'C', 'N', A, B, alpha, beta)
return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
Expand All @@ -402,7 +402,7 @@ end
if A === B
return herk_wrapper!(C, 'N', A, alpha, beta)
else
return gemm_wrapper!(C, 'N', 'C', A, B, alpha, beta)
return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
Expand All @@ -415,7 +415,7 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat}
A = adjA.parent
B = adjB.parent
return gemm_wrapper!(C, 'C', 'C', A, B, alpha, beta)
return gemm_wrapper!(C, 'C', 'C', A, B, MulAddMul(alpha, beta))
end
@inline function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number)
Expand Down Expand Up @@ -508,7 +508,7 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
end
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
return gemm_wrapper!(C, tA, tAt, A, A, MulAddMul(α, β))
end

function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
Expand Down Expand Up @@ -547,7 +547,7 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
end
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
return gemm_wrapper!(C, tA, tAt, A, A, MulAddMul(α, β))
end

function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
Expand All @@ -561,7 +561,7 @@ end

function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number=true, β::Number=false) where {T<:BlasFloat}
_add = MulAddMul()) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)

Expand All @@ -573,21 +573,21 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end

if mA == 0 || nA == 0 || nB == 0 || iszero(α)
if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
if size(C) != (mA, nB)
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
end
return _rmul_or_fill!(C, β)
return _rmul_or_fill!(C, _add.beta)
end

if mA == 2 && nA == 2 && nB == 2
return matmul2x2!(C, tA, tB, A, B, MulAddMul(α, β))
return matmul2x2!(C, tA, tB, A, B, _add)
end
if mA == 3 && nA == 3 && nB == 3
return matmul3x3!(C, tA, tB, A, B, MulAddMul(α, β))
return matmul3x3!(C, tA, tB, A, B, _add)
end

alpha, beta = promote(α, β, zero(T))
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
Expand All @@ -596,7 +596,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
stride(C, 2) >= size(C, 1))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
generic_matmatmul!(C, tA, tB, A, B, MulAddMul(α, β))
generic_matmatmul!(C, tA, tB, A, B, _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down