Skip to content

Commit

Permalink
Improve tests in blockdiagonal.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
nickrobinson251 committed Feb 15, 2021
1 parent 9bf2360 commit 5477835
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions test/blockdiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@ end
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]

@testset "$T" for (T, (b1, b2, b3)) in (
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
)
A = rand(rng, N, N + N1)
B = rand(rng, N + N1, N + N2)
A′, B′ = A', B'
a = rand(rng, N)
b = rand(rng, N + N1)
@testset for V in (Tuple, Vector)
b1 = BlockDiagonal(V(blocks1))
b2 = BlockDiagonal(V(blocks2))
N = size(b1, 1)

@testset "AbstractArray" begin
X = rand(2, 2); Y = rand(3, 3)
X = rand(2, 2)
Y = rand(3, 3)

@test size(b1) == (N, N)
@test size(b1, 1) == N && size(b1, 2) == N
Expand All @@ -53,7 +49,7 @@ end
end

@testset "parent" begin
@test parent(b1) isa Union{Tuple,AbstractVector}
@test parent(b1) isa V
@test eltype(parent(b1)) <: AbstractMatrix
@test parent(BlockDiagonal([X, Y])) == [X, Y]
@test parent(BlockDiagonal((X, Y))) == (X, Y)
Expand All @@ -66,7 +62,7 @@ end
end

@testset "setindex!" begin
X = BlockDiagonal([rand(Float32, 5, 5), rand(Float32, 3, 3)])
X = BlockDiagonal(V([rand(Float32, 5, 5), rand(Float32, 3, 3)]))
X[10] = Int(10)
@test X[10] === Float32(10.0)
X[3, 3] = Int(9)
Expand All @@ -78,14 +74,15 @@ end

@testset "ChainRules" begin
@testset "BlockDiagonal" begin
x = [randn(1, 2), randn(2, 2)]
= [randn(1, 2), randn(2, 2)]
= Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)])
x = V([randn(1, 2), randn(2, 2)])
= V([randn(1, 2), randn(2, 2)])

= Composite{typeof(BlockDiagonal(x))}(blocks=V([randn(1, 2), randn(2, 2)]))
rrule_test(BlockDiagonal, ȳ, (x, x̄))
end
@testset "Matrix" begin
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
= Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), )
D = BlockDiagonal(V([randn(1, 2), randn(2, 2)]))
= Composite{typeof(D)}((blocks=V([randn(1, 2), randn(2, 2)])),)
= randn(size(D))
rrule_test(Matrix, Ȳ, (D, D̄))
end
Expand All @@ -98,9 +95,9 @@ end
end

@testset "blocks size" begin
B = BlockDiagonal([rand(3, 3), rand(4, 4)])
B = BlockDiagonal(V([rand(3, 3), rand(4, 4)]))
@test nblocks(B) == 2
@test blocksizes(B) == [(3, 3), (4, 4)]
@test blocksizes(B) == V([(3, 3), (4, 4)])
@test blocksize(B, 2) == blocksizes(B)[2] == blocksize(B, 2, 2)
end

Expand All @@ -124,8 +121,8 @@ end
@testset "Non-Square Matrix" begin
A1 = ones(2, 4)
A2 = 2 * ones(3, 2)
B1 = BlockDiagonal([A1, A2])
B2 = [A1 zeros(2, 2); zeros(3, 4) A2]
B1 = BlockDiagonal(V([A1, A2]))
B2 = [A1 zeros(2, 2); zeros(3, 4) A2]

@test B1 == B2
# Dimension check
Expand Down

0 comments on commit 5477835

Please sign in to comment.