Skip to content

Commit

Permalink
Merge branch 'master' into mz/promote
Browse files Browse the repository at this point in the history
  • Loading branch information
mzgubic authored May 13, 2022
2 parents 013b9be + 57e4ec7 commit 3b4e37b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockDiagonals"
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.27"
version = "0.1.29"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
10 changes: 4 additions & 6 deletions src/blockdiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,15 @@ end
# that the same entry is available via `getblock(B, p)[i, end+j]`; `p = -1` if no such `p`.
function _block_indices(B::BlockDiagonal, i::Integer, j::Integer)
all((0, 0) .< (i, j) .<= size(B)) || throw(BoundsError(B, (i, j)))
nrows = size.(blocks(B), 1)
ncols = size.(blocks(B), 2)
# find the on-diagonal block `p` in column `j`
p = 0
while j > 0
@inbounds while j > 0
p += 1
j -= ncols[p]
j -= blocksize(B, p)[2]
end
i -= sum(nrows[1:(p-1)])
@views @inbounds i -= sum(size.(blocks(B)[1:(p-1)], 1))
# if row `i` outside of block `p`, set `p` to place-holder value `-1`
if i <= 0 || i > nrows[p]
if i <= 0 || i > blocksize(B, p)[1]
p = -1
end
return p, i, j
Expand Down
11 changes: 9 additions & 2 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ for f in (:adjoint, :eigvecs, :inv, :pinv, :transpose)
@eval LinearAlgebra.$f(B::BlockDiagonal) = BlockDiagonal(map($f, blocks(B)))
end

LinearAlgebra.diag(B::BlockDiagonal) = mapreduce(diag, vcat, blocks(B))
LinearAlgebra.diag(B::BlockDiagonal) = map(i -> getindex(B, i, i), 1:minimum(size(B)))
LinearAlgebra.det(B::BlockDiagonal) = prod(det, blocks(B))
LinearAlgebra.logdet(B::BlockDiagonal) = sum(logdet, blocks(B))
LinearAlgebra.tr(B::BlockDiagonal) = sum(tr, blocks(B))
Expand Down Expand Up @@ -73,7 +73,14 @@ end


svdvals_blockwise(B::BlockDiagonal) = mapreduce(svdvals, vcat, blocks(B))
LinearAlgebra.svdvals(B::BlockDiagonal) = sort!(svdvals_blockwise(B); rev=true)
function LinearAlgebra.svdvals(B::BlockDiagonal)
# if all the blocks are squares
if all(map(t -> t[1] == t[2], blocksizes(B)))
return sort!(svdvals_blockwise(B); rev=true)
else
return svdvals(Matrix(B))
end
end

# `B = U * Diagonal(S) * Vt` with `U` and `Vt` `BlockDiagonal` (`S` only sorted block-wise).
function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T
Expand Down
5 changes: 5 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
b_nonsq = BlockDiagonal([rand(rng, N1, N2), rand(rng, N2, N1)])
A = rand(rng, N, N + N1)
B = rand(rng, N + N1, N + N2)
A′, B′ = A', B'
Expand All @@ -36,8 +37,12 @@ end
end

@testset "Unary Linear Algebra" begin
nonsquare = (adjoint, diag, pinv, svdvals, transpose)
@testset "$f" for f in (adjoint, det, diag, eigvals, inv, pinv, svdvals, transpose, tr)
@test f(b1) f(Matrix(b1))
if f in nonsquare
@test f(b_nonsq) f(Matrix(b_nonsq))
end
end

@testset "permute=$p, scale=$s" for p in (true, false), s in (true, false)
Expand Down

0 comments on commit 3b4e37b

Please sign in to comment.