Skip to content

Commit

Permalink
Add fallback for +(::BlockDiagonal,::Diagonal) when blocks are not …
Browse files Browse the repository at this point in the history
…square (#120)

* Add fallback to `+(::BlockDiagonal,::Diagonal)` for nonsquare blocks

* Add tests

* Bump version
  • Loading branch information
mjp98 authored Nov 2, 2022
1 parent 960aa87 commit 10f2c9f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockDiagonals"
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.37"
version = "0.1.38"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion src/base_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ function Base.:+(B::BlockDiagonal, M::StridedMatrix)
return A
end

function Base.:+(B::BlockDiagonal, M::Diagonal)::BlockDiagonal
function Base.:+(B::BlockDiagonal, M::Diagonal)
size(B) == size(M) || throw(DimensionMismatch("dimensions must match"))
if !all(is_square, blocks(B))
return Matrix(B) + M # Fallback on the generic Base method
end
A = copy(B)
d = diag(M)
row = 1
Expand Down
15 changes: 13 additions & 2 deletions test/base_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ using Test
b64 = BlockDiagonal([rand(rng, 2, 2), rand(rng, 2, 2)])
b32 = BlockDiagonal([rand(rng, Float32, 2, 2), rand(rng, Float32, 2, 2)])

bns = BlockDiagonal([rand(rng, N1, N2), rand(rng, N2, N3), rand(rng, N3, N1)])

@testset "Addition" begin
@testset "BlockDiagonal + BlockDiagonal" begin
@test b1 + b1 isa BlockDiagonal
Expand Down Expand Up @@ -58,6 +60,15 @@ using Test
@test D + b1 isa BlockDiagonal
@test D + b1 == D + Matrix(b1)
@test_throws DimensionMismatch D′ + b1

# Non-square blocks
@test D + bns isa Matrix
@test D + bns == D + Matrix(bns)
@test_throws DimensionMismatch D′ + bns

@test bns + D isa Matrix
@test bns + D == D + Matrix(bns)
@test_throws DimensionMismatch bns + D′
end

@testset "BlockDiagonal + UniformScaling" begin
Expand All @@ -69,11 +80,11 @@ using Test
@test 5I + b1 == 5I + Matrix(b1)
end
end # Addition

@testset "Subtraction" begin
@test -b1 isa BlockDiagonal
@test b1 - b1 isa BlockDiagonal

@test -b1 == -Matrix(b1)
@test b1 - b1 == Matrix(b1) - Matrix(b1)
@test Matrix(b1) - b2 == Matrix(b1) - Matrix(b2)
Expand Down

2 comments on commit 10f2c9f

@mjp98
Copy link
Contributor Author

@mjp98 mjp98 commented on 10f2c9f Nov 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/71505

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.38 -m "<description of version>" 10f2c9f3469b79d11bfc015b083501821af68e7e
git push origin v0.1.38

Please sign in to comment.