diff --git a/Project.toml b/Project.toml index de33516..2f2c697 100644 --- a/Project.toml +++ b/Project.toml @@ -1,20 +1,24 @@ name = "BlockDiagonals" uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.6" +version = "0.1.7" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] -FillArrays = "0.6, 0.7, 0.8" +ChainRulesCore = "0.9" +FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10" julia = "1" [extras] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Documenter", "Random", "Test"] +test = ["ChainRulesTestUtils", "Documenter", "FiniteDifferences", "Random", "Test"] diff --git a/src/BlockDiagonals.jl b/src/BlockDiagonals.jl index b6af37e..1d5aaf4 100644 --- a/src/BlockDiagonals.jl +++ b/src/BlockDiagonals.jl @@ -1,6 +1,7 @@ module BlockDiagonals using Base: @propagate_inbounds +using ChainRulesCore using FillArrays: Zeros using LinearAlgebra diff --git a/src/blockdiagonal.jl b/src/blockdiagonal.jl index c1dc34d..99301c7 100644 --- a/src/blockdiagonal.jl +++ b/src/blockdiagonal.jl @@ -17,6 +17,11 @@ function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}} return BlockDiagonal{T, V}(blocks) end +function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V} + BlockDiagonal_pullback(Δ::Composite) = (NO_FIELDS, Δ.blocks) + return BlockDiagonal(blocks), BlockDiagonal_pullback +end + BlockDiagonal(B::BlockDiagonal) = B is_square(A::AbstractMatrix) = size(A, 1) == size(A, 2) @@ -102,6 +107,24 @@ end ## Base Base.Matrix(B::BlockDiagonal) = cat(blocks(B)...; dims=(1, 2)) + +function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagonal} + nrows = size.(B.blocks, 1) + ncols = size.(B.blocks, 2) + function Matrix_pullback(Δ::Matrix) + row_idxs = cumsum(nrows) .- nrows .+ 1 + col_idxs = cumsum(ncols) .- ncols .+ 1 + + Δblocks = map(eachindex(nrows)) do n + block_rows = row_idxs[n]:(row_idxs[n] + nrows[n] - 1) + block_cols = col_idxs[n]:(col_idxs[n] + ncols[n] - 1) + return Δ[block_rows, block_cols] + end + return (NO_FIELDS, Composite{T}(blocks=Δblocks)) + end + return Matrix(B), Matrix_pullback +end + Base.size(B::BlockDiagonal) = sum(first∘size, blocks(B)), sum(last∘size, blocks(B)) Base.similar(B::BlockDiagonal) = BlockDiagonal(map(similar, blocks(B))) Base.parent(B::BlockDiagonal) = B.blocks diff --git a/test/blockdiagonal.jl b/test/blockdiagonal.jl index 966f626..8a67c5e 100644 --- a/test/blockdiagonal.jl +++ b/test/blockdiagonal.jl @@ -3,6 +3,16 @@ using BlockDiagonals: isequal_blocksizes using Random using Test +function FiniteDifferences.to_vec(X::BlockDiagonal) + x, blocks_from_vec = to_vec(X.blocks) + BlockDiagonal_from_vec(x_vec) = BlockDiagonal(blocks_from_vec(x_vec)) + return x, BlockDiagonal_from_vec +end + +function Base.isapprox(C::Composite{<:BlockDiagonal}, D::BlockDiagonal; kwargs...) + return isapprox(C.blocks, D.blocks; kwargs...) +end + @testset "blockdiagonal.jl" begin rng = MersenneTwister(123456) N1, N2, N3 = 3, 4, 5 @@ -61,6 +71,21 @@ using Test end end # AbstractArray + @testset "ChainRules" begin + @testset "BlockDiagonal" begin + x = [randn(1, 2), randn(2, 2)] + x̄ = [randn(1, 2), randn(2, 2)] + ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)]) + rrule_test(BlockDiagonal, ȳ, (x, x̄)) + end + @testset "Matrix" begin + D = BlockDiagonal([randn(1, 2), randn(2, 2)]) + D̄ = Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), ) + Ȳ = randn(size(D)) + rrule_test(Matrix, Ȳ, (D, D̄)) + end + end + @testset "isequal_blocksizes" begin @test isequal_blocksizes(b1, b1) == true @test isequal_blocksizes(b1, similar(b1)) == true diff --git a/test/runtests.jl b/test/runtests.jl index 38a2cb3..8557958 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,8 @@ using BlockDiagonals +using ChainRulesCore +using ChainRulesTestUtils using Documenter +using FiniteDifferences # For overloading to_vec using Test @testset "BlockDiagonals" begin