From 723649f96ca4a9274a68ac5e88eb9e8dd25ea375 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 10 Oct 2024 18:37:17 +0530 Subject: [PATCH 1/3] Call `MulAddMul` instead of multiplication in _generic_matmatmul! --- stdlib/LinearAlgebra/src/matmul.jl | 5 +++-- stdlib/LinearAlgebra/test/matmul.jl | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index b70f7d47b28dd..d305b37758950 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -919,7 +919,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) @noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, - _add::MulAddMul) where {T,S,R} + _add::MulAddMul{ais1,bis0}) where {T,S,R,ais1,bis0} AxM = axes(A, 1) AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector` BxK = axes(B, 1) @@ -935,11 +935,12 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end + _add_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false) if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in BxN, k in BxK - Balpha = B[k,n]*_add.alpha + Balpha = _add_alpha(B[k,n]) @simd for m in AxM C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 4c79451ebfc8b..0d1e2776d2bb3 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -1130,4 +1130,22 @@ end @test a * transpose(B) ≈ A * transpose(B) end +@testset "issue #56085" begin + struct Thing + data::Float64 + end + + Base.zero(::Type{Thing}) = Thing(0.) + Base.zero(::Thing) = Thing(0.) + Base.one(::Type{Thing}) = Thing(1.) + Base.one(::Thing) = Thing(1.) + Base.:+(t::Thing...) = +(getfield.(t, :data)...) + Base.:*(t::Thing...) = *(getfield.(t, :data)...) + + M = Float64[1 2; 3 4] + A = Thing.(M) + + @test A * A ≈ M * M +end + end # module TestMatmul From 9ad92bd6756fd56bbfc2dbfbe567f70c8af92efb Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 10 Oct 2024 18:39:33 +0530 Subject: [PATCH 2/3] Remove bis0 param from method signature --- stdlib/LinearAlgebra/src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index d305b37758950..5085388a827ca 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -919,7 +919,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) @noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, - _add::MulAddMul{ais1,bis0}) where {T,S,R,ais1,bis0} + _add::MulAddMul{ais1}) where {T,S,R,ais1} AxM = axes(A, 1) AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector` BxK = axes(B, 1) From 2a6ce58958c37b804034272d42eb5c8f93e77cb1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 13 Oct 2024 10:47:02 +0530 Subject: [PATCH 3/3] Add comments --- stdlib/LinearAlgebra/src/matmul.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 5085388a827ca..02ecd74152531 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -935,12 +935,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end - _add_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false) + _rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false) if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in BxN, k in BxK - Balpha = _add_alpha(B[k,n]) + # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) + Balpha = _rmul_alpha(B[k,n]) @simd for m in AxM C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end