Skip to content

Commit

Permalink
Update rrules
Browse files Browse the repository at this point in the history
  • Loading branch information
nickrobinson251 committed Feb 15, 2021
1 parent 80d1e1a commit 70ddc17
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function ChainRulesCore.rrule(
::typeof(*),
bm::BlockDiagonal{T, V},
v::StridedVector{T}
) where {T<:Union{Real, Complex}, V<:Matrix{T}}
) where {T<:Union{Real, Complex}, V}

y = bm * v

Expand Down
26 changes: 14 additions & 12 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
@testset "chainrules.jl" begin
@testset "BlockDiagonal" begin
x = [randn(1, 2), randn(2, 2)]
test_rrule(BlockDiagonal, x)
end
@testset for V in (Tuple, Vector)
@testset "BlockDiagonal" begin
x = V([randn(1, 2), randn(2, 2)])
test_rrule(BlockDiagonal, x)
end

@testset "Matrix" begin
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
test_rrule(Matrix, D)
end
@testset "Matrix" begin
B = BlockDiagonal(V([randn(1, 2), randn(2, 2)]))
test_rrule(Matrix, B)
end

@testset "BlockDiagonal * Vector" begin
D = BlockDiagonal([rand(2, 3), rand(3, 3)])
v = rand(6)
test_rrule(*, D, v)
@testset "BlockDiagonal * Vector" begin
B = BlockDiagonal(V([rand(2, 3), rand(3, 3)]))
v = rand(6)
test_rrule(*, B, v)
end
end
end

0 comments on commit 70ddc17

Please sign in to comment.