Skip to content

Commit

Permalink
Extend lmul! for LowerTriangular
Browse files Browse the repository at this point in the history
  • Loading branch information
mjp98 committed Nov 2, 2022
1 parent 960aa87 commit 1f4f044
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
"""
eigen_blockwise(B::BlockDiagonal, args...; kwargs...) -> values, vectors
Computes the eigendecomposition for each block separately and keeps the block diagonal
Computes the eigendecomposition for each block separately and keeps the block diagonal
structure in the matrix of eigenvectors. Hence any parameters given are applied to each
eigendecomposition separately, but there is e.g. no global sorting of eigenvalues.
"""
Expand All @@ -58,16 +58,16 @@ function eigen_blockwise(B::BlockDiagonal, args...; kwargs...)
values = promote([e.values for e in eigens]...)
vectors = promote([e.vectors for e in eigens]...)
vcat(values...), BlockDiagonal([vectors...])
end
end

## This function never keeps the block diagonal structure
function LinearAlgebra.eigen(B::BlockDiagonal, args...; kwargs...)
values, vectors = eigen_blockwise(B, args...; kwargs...)
vectors = Matrix(vectors) # always convert to avoid type instability (also it speeds up the permutation step)
@static if VERSION > v"1.2.0-DEV.275"
Eigen(LinearAlgebra.sorteig!(values, vectors, kwargs...)...)
else
Eigen(values, vectors)
else
Eigen(values, vectors)
end
end

Expand Down Expand Up @@ -157,6 +157,20 @@ function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number,
return C
end

function LinearAlgebra.lmul!(B::LowerTriangular{<:Any,<:BlockDiagonal}, vm::StridedVecOrMat)
row_i = 1
# BlockDiagonals with non-square blocks
if !all(BlockDiagonals.is_square, blocks(parent(B)))
return lmul!(LowerTriangular(Matrix(B)), vm) # Fallback on the generic LinearAlgebra method
end
for block in blocks(parent(B))
nrow = size(block, 1)
@views lmul!(LowerTriangular(block), vm[row_i:(row_i + nrow - 1), :])
row_i += nrow
end
vm
end

function LinearAlgebra.:\(B::BlockDiagonal, vm::AbstractVecOrMat)
row_i = 1
# BlockDiagonals with non-square blocks
Expand Down

0 comments on commit 1f4f044

Please sign in to comment.