Skip to content

Commit

Permalink
Merge pull request #41 from invenia/wct/chainrules
Browse files Browse the repository at this point in the history
Add some ChainRules
  • Loading branch information
oxinabox authored Dec 2, 2020
2 parents 8e45cdc + dc8a36a commit acd5cb8
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 3 deletions.
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
name = "BlockDiagonals"
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.6"
version = "0.1.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
FillArrays = "0.6, 0.7, 0.8"
ChainRulesCore = "0.9"
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Documenter", "Random", "Test"]
test = ["ChainRulesTestUtils", "Documenter", "FiniteDifferences", "Random", "Test"]
1 change: 1 addition & 0 deletions src/BlockDiagonals.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module BlockDiagonals

using Base: @propagate_inbounds
using ChainRulesCore
using FillArrays: Zeros
using LinearAlgebra

Expand Down
23 changes: 23 additions & 0 deletions src/blockdiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
return BlockDiagonal{T, V}(blocks)
end

function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V}
BlockDiagonal_pullback::Composite) = (NO_FIELDS, Δ.blocks)
return BlockDiagonal(blocks), BlockDiagonal_pullback
end

BlockDiagonal(B::BlockDiagonal) = B

is_square(A::AbstractMatrix) = size(A, 1) == size(A, 2)
Expand Down Expand Up @@ -102,6 +107,24 @@ end

## Base
Base.Matrix(B::BlockDiagonal) = cat(blocks(B)...; dims=(1, 2))

function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagonal}
nrows = size.(B.blocks, 1)
ncols = size.(B.blocks, 2)
function Matrix_pullback::Matrix)
row_idxs = cumsum(nrows) .- nrows .+ 1
col_idxs = cumsum(ncols) .- ncols .+ 1

Δblocks = map(eachindex(nrows)) do n
block_rows = row_idxs[n]:(row_idxs[n] + nrows[n] - 1)
block_cols = col_idxs[n]:(col_idxs[n] + ncols[n] - 1)
return Δ[block_rows, block_cols]
end
return (NO_FIELDS, Composite{T}(blocks=Δblocks))
end
return Matrix(B), Matrix_pullback
end

Base.size(B::BlockDiagonal) = sum(firstsize, blocks(B)), sum(lastsize, blocks(B))
Base.similar(B::BlockDiagonal) = BlockDiagonal(map(similar, blocks(B)))
Base.parent(B::BlockDiagonal) = B.blocks
Expand Down
25 changes: 25 additions & 0 deletions test/blockdiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ using BlockDiagonals: isequal_blocksizes
using Random
using Test

function FiniteDifferences.to_vec(X::BlockDiagonal)
x, blocks_from_vec = to_vec(X.blocks)
BlockDiagonal_from_vec(x_vec) = BlockDiagonal(blocks_from_vec(x_vec))
return x, BlockDiagonal_from_vec
end

function Base.isapprox(C::Composite{<:BlockDiagonal}, D::BlockDiagonal; kwargs...)
return isapprox(C.blocks, D.blocks; kwargs...)
end

@testset "blockdiagonal.jl" begin
rng = MersenneTwister(123456)
N1, N2, N3 = 3, 4, 5
Expand Down Expand Up @@ -61,6 +71,21 @@ using Test
end
end # AbstractArray

@testset "ChainRules" begin
@testset "BlockDiagonal" begin
x = [randn(1, 2), randn(2, 2)]
= [randn(1, 2), randn(2, 2)]
= Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)])
rrule_test(BlockDiagonal, ȳ, (x, x̄))
end
@testset "Matrix" begin
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
= Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), )
= randn(size(D))
rrule_test(Matrix, Ȳ, (D, D̄))
end
end

@testset "isequal_blocksizes" begin
@test isequal_blocksizes(b1, b1) == true
@test isequal_blocksizes(b1, similar(b1)) == true
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using BlockDiagonals
using ChainRulesCore
using ChainRulesTestUtils
using Documenter
using FiniteDifferences # For overloading to_vec
using Test

@testset "BlockDiagonals" begin
Expand Down

2 comments on commit acd5cb8

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/25691

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.7 -m "<description of version>" acd5cb830a1795aedca31ab5f41754ab98b5c6f1
git push origin v0.1.7

Please sign in to comment.