diff --git a/Project.toml b/Project.toml index b8526d216..dab6288db 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 1b5eba024..66dda0909 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -8,6 +8,7 @@ using LinearAlgebra using LinearAlgebra.BLAS using Random using RealDot: realdot +using SparseArrays using Statistics # Basically everything this package does is overloading these, so we make an exception @@ -43,6 +44,8 @@ include("rulesets/LinearAlgebra/symmetric.jl") include("rulesets/LinearAlgebra/factorization.jl") include("rulesets/LinearAlgebra/uniformscaling.jl") +include("rulesets/SparseArrays/sparsematrix.jl") + include("rulesets/Random/random.jl") end # module diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl new file mode 100644 index 000000000..22a5d0366 --- /dev/null +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -0,0 +1,25 @@ +function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::AbstractVector, m, n, combine::typeof(+)) + project_V = ProjectTo(V) + + function sparse_pullback(Ω̄) + ΔΩ = unthunk(Ω̄) + ΔV = project_V(ΔΩ[I .+ m .* (J .- 1)]) + return NoTangent(), NoTangent(), NoTangent(), ΔV, NoTangent(), NoTangent(), NoTangent() + end + + return sparse(I, J, V, m, n, combine), sparse_pullback +end + +function rrule(::Type{T}, A::AbstractMatrix) where T <: SparseMatrixCSC + function sparse_pullback(Ω̄) + return NoTangent(), Ω̄ + end + return T(A), sparse_pullback +end + +function rrule(::Type{T}, v::AbstractVector) where T <: SparseVector + function sparse_pullback(Ω̄) + return NoTangent(), Ω̄ + end + return T(v), sparse_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl new file mode 100644 index 000000000..3e239cb1c --- /dev/null +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -0,0 +1,19 @@ + +@testset "sparse(I, J, V, m, n, +)" begin + m, n = 3, 5 + s, t, w = [1,2], [2,3], [0.5,0.5] + + test_rrule(sparse, s, t, w, m, n, +) +end + +@testset "SparseMatrixCSC(A)" begin + A = rand(5, 3) + test_rrule(SparseMatrixCSC, A) + test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-5) +end + +@testset "SparseVector(v)" begin + v = rand(5) + test_rrule(SparseVector, v) + test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-5) +end diff --git a/test/runtests.jl b/test/runtests.jl index 201cb406f..24c1d85b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,7 @@ using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot using Random +using SparseArrays using StaticArrays using Statistics using Test @@ -75,6 +76,10 @@ end println() + include_test("rulesets/SparseArrays/sparsematrix.jl") + + println() + include_test("rulesets/Random/random.jl") println() end