Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move reinterpret-based optimization for complex matrix * real vec/mat to lower level. #44052

Merged
merged 3 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 82 additions & 48 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,16 @@ end
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, 'N', A, x, alpha, beta)

# Complex matrix times real vector.
# `A` and `y` are viewed as arrays of their real element type through `reinterpret`, and
# the product is then computed by the real matrix-vector `mul!`.
for elty in (Float32, Float64)
    @eval @inline function mul!(y::StridedVector{Complex{$elty}},
                                A::StridedVecOrMat{Complex{$elty}},
                                x::StridedVector{$elty},
                                alpha::Real, beta::Real)
        A_real = reinterpret($elty, A)
        y_real = reinterpret($elty, y)
        # Writing into the reinterpreted view updates `y` itself.
        mul!(y_real, A_real, x, alpha, beta)
        return y
    end
end
# Complex matrix times real vector.
# Reinterpret the matrix as a real matrix and do the real matvec computation; the call is
# forwarded to `gemv!` with the no-transpose flag 'N'.
@inline function mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}},
                      x::StridedVector{T},
                      alpha::Number, beta::Number) where {T<:BlasReal}
    return gemv!(y, 'N', A, x, alpha, beta)
end

# Real matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta)

@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
Expand Down Expand Up @@ -191,18 +184,6 @@ end
# Wrapped (adjoint/transpose) real matrix times a complex matrix.
# The product is formed in transposed form, `transpose(B) * parent(A)`, transposed back,
# and finally materialized with `copy`.
function (*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
    transposed_product = transpose(B) * parent(A)
    return copy(transpose(transposed_product))
end
# Real matrix (plain or adjoint/transpose-wrapped) times a wrapped complex matrix.
# `parent(B) * transpose(A)` is computed, the operation returned by `wrapperop(B)`
# (presumably the adjoint/transpose that wraps `B`) is applied to it, and the result is
# materialized with `copy`.
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex})
    rewrap = wrapperop(B)
    inner_product = parent(B) * transpose(A)
    return copy(rewrap(inner_product))
end

# Complex matrix times real matrix.
# `A` and `C` are viewed as arrays of their real element type through `reinterpret`, and
# the product is then computed by the real `mul!`.
for elty in (Float32, Float64)
    @eval @inline function mul!(C::StridedMatrix{Complex{$elty}},
                                A::StridedVecOrMat{Complex{$elty}},
                                B::StridedVecOrMat{$elty},
                                alpha::Real, beta::Real)
        A_real = reinterpret($elty, A)
        C_real = reinterpret($elty, C)
        # Writing into the reinterpreted view updates `C` itself.
        mul!(C_real, A_real, B, alpha, beta)
        return C
    end
end

"""
muladd(A, y, z)

Expand Down Expand Up @@ -409,18 +390,14 @@ end
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
end
end
# Complex matrix times transposed real matrix.
# For efficiency the complex `A` and `C` are viewed as real arrays through `reinterpret`;
# the lazily transposed `tB` is handed on unchanged to the real `mul!`.
for elty in (Float32, Float64)
    @eval @inline function mul!(C::StridedMatrix{Complex{$elty}},
                                A::StridedVecOrMat{Complex{$elty}},
                                tB::Transpose{<:Any,<:StridedVecOrMat{$elty}},
                                alpha::Real, beta::Real)
        A_real = reinterpret($elty, A)
        C_real = reinterpret($elty, C)
        # Writing into the reinterpreted view updates `C` itself.
        mul!(C_real, A_real, tB, alpha, beta)
        return C
    end
end
# Complex matrix times real matrix.
# Forwarded to `gemm_wrapper!` with both operands flagged 'N' (no transpose); the scaling
# factors travel as a `MulAddMul(alpha, beta)`.
@inline function mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}},
                      B::StridedVecOrMat{T},
                      alpha::Number, beta::Number) where {T<:BlasReal}
    return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
end
# Complex matrix times transposed real matrix.
# The lazy `Transpose` wrapper is removed with `parent` and the transposition is conveyed to
# `gemm_wrapper!` through the 'T' flag for the second operand instead.
@inline function mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}},
                      tB::Transpose{<:Any,<:StridedVecOrMat{T}},
                      alpha::Number, beta::Number) where {T<:BlasReal}
    return gemm_wrapper!(C, 'N', 'T', A, parent(tB), MulAddMul(alpha, beta))
end

# collapsing the following two defs with C::AbstractVecOrMat yields ambiguities
@inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
Expand Down Expand Up @@ -512,22 +489,36 @@ end
function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T},
α::Number=true, β::Number=false) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
if nA != length(x)
nA != length(x) &&
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
end
if mA != length(y)
mA != length(y) &&
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
return BLAS.gemv!(tA, alpha, A, x, beta, y)
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
if mA == 0
return y
end
if nA == 0
return _rmul_or_fill!(y, β)
end
end

function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
α::Number = true, β::Number = false) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
nA != length(x) &&
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
mA != length(y) &&
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
return BLAS.gemv!(tA, alpha, A, x, beta, y)
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) &&
stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y`
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
Expand Down Expand Up @@ -680,6 +671,49 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
generic_matmatmul!(C, tA, tB, A, B, _add)
end

# `gemm_wrapper!` for a complex `A` and `C` combined with a real `B`.
#
# Steps, in order:
#   1. validate that the inner dimensions agree and that `C` is not aliased with an input;
#   2. for an empty product or a zero `alpha`, only rescale/fill `C` with `beta`;
#   3. send 2x2 and 3x3 problems to `matmul2x2!` / `matmul3x3!`;
#   4. when the scalars and strides allow it and `tA == 'N'`, call `BLAS.gemm!` on `A` and
#      `C` reinterpreted as arrays of the real type `T`;
#   5. otherwise fall back to `generic_matmatmul!`.
#
# Throws `DimensionMismatch` or `ArgumentError` from the checks in steps 1 and 2.
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
                       A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
                       _add = MulAddMul()) where {T<:BlasReal}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)

    nA == mB ||
        throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))

    (C === A || B === C) &&
        throw(ArgumentError("output matrix must not be aliased with input matrix"))

    # Nothing to multiply: the result is just `C` scaled (or filled) according to `beta`.
    if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
        size(C) == (mA, nB) ||
            throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
        return _rmul_or_fill!(C, _add.beta)
    end

    # Dedicated kernels for the small square cases.
    mA == nA == nB == 2 && return matmul2x2!(C, tA, tB, A, B, _add)
    mA == nA == nB == 3 && return matmul3x3!(C, tA, tB, A, B, _add)

    alpha, beta = promote(_add.alpha, _add.beta, zero(T))

    # The reinterpret-based BLAS call is used only when it is BLAS-compatible: scalars of
    # type `Bool` or `T`, unit stride along the first dimension of every operand, second
    # strides no smaller than the first dimension, and an untransposed `A`.
    if (alpha isa Union{Bool,T} &&
        beta isa Union{Bool,T} &&
        stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
        stride(A, 2) >= size(A, 1) &&
        stride(B, 2) >= size(B, 1) &&
        stride(C, 2) >= size(C, 1)) && tA == 'N'
        BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
        return C
    else
        return generic_matmatmul!(C, tA, tB, A, B, _add)
    end
end

# blas.jl defines matmul for floats; other integer and mixed precision
# cases are handled here

Expand Down
59 changes: 31 additions & 28 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,34 +226,37 @@ end
end
end

@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32, Float64)
for T2 in (Float32, Float64)
for arg1_real in (true, false)
@testset "Combination $T1 $T2 $arg1_real $arg2_real" for arg2_real in (true, false)
A0 = reshape(Vector{T1}(1:25), 5, 5) .+
(arg1_real ? 0 : 1im * reshape(Vector{T1}(-3:21), 5, 5))
A = view(A0, 1:2, 1:2)
B = Matrix{T2}([1.0 3.0; -1.0 2.0]) .+
(arg2_real ? 0 : 1im * Matrix{T2}([3.0 4; -1 10]))
AB_correct = copy(A) * B
AB = A * B # view times matrix
@test AB ≈ AB_correct
A1 = view(A0, :, 1:2) # rectangular view times matrix
@test A1 * B ≈ copy(A1) * B
B1 = view(B, 1:2, 1:2)
AB1 = A * B1 # view times view
@test AB1 ≈ AB_correct
x = Vector{T2}([1.0; 10.0]) .+ (arg2_real ? 0 : 1im * Vector{T2}([3; -1]))
Ax_exact = copy(A) * x
Ax = A * x # view times vector
@test Ax ≈ Ax_exact
x1 = view(x, 1:2)
Ax1 = A * x1 # view times viewed vector
@test Ax1 ≈ Ax_exact
@test copy(A) * x1 ≈ Ax_exact # matrix times viewed vector
# View times transposed matrix
Bt = transpose(B)
@test A * Bt ≈ A * copy(Bt)
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T in (Float32, Float64)
A0 = randn(complex(T), 10, 10)
B0 = randn(T, 10, 10)
@testset "Combination Mat{$(complex(T))} Mat{$T}" for Bax1 in (1:5, 2:2:10), Bax2 in (1:5, 2:2:10)
B = view(A0, Bax1, Bax2)
tB = transpose(B)
Bd, tBd = copy(B), copy(tB)
for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10)
A = view(A0, Aax1, Aax2)
AB_correct = copy(A) * Bd
AtB_correct = copy(A) * tBd
@test A*Bd ≈ AB_correct # view times matrix
@test A*B ≈ AB_correct # view times view
@test A*tBd ≈ AtB_correct # view times transposed matrix
@test A*tB ≈ AtB_correct # view times transposed view
end
end
x = randn(T, 10)
y0 = similar(A0, 20)
@testset "Combination Mat{$(complex(T))} Vec{$T}" for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10)
A = view(A0, Aax1, Aax2)
Ad = copy(A)
for indx in (1:5, 1:2:10, 6:-1:2)
vx = view(x, indx)
dx = x[indx]
Ax_correct = Ad*dx
@test A*vx ≈ A*dx ≈ Ad*vx ≈ Ax_correct # view/matrix times view/vector
for indy in (1:2:2size(A,1), size(A,1):-1:1)
y = view(y0, indy)
@test mul!(y, A, vx) ≈ mul!(y, A, dx) ≈ mul!(y, Ad, vx) ≈
mul!(y, Ad, dx) ≈ Ax_correct # test for uncontiguous dest
end
end
end
Expand Down