Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic sparse handling #762

Closed
wants to merge 19 commits into from
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand Down
30 changes: 30 additions & 0 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Random, FillArrays, AbstractFFTs
using FillArrays: AbstractFill, getindex_value
using Base.Broadcast: broadcasted, broadcast_shape
using SparseArrays
using Distributed: pmap, AbstractWorkerPool

@adjoint (::Type{T})(::UndefInitializer, args...) where T<:Array = T(undef, args...), Δ -> nothing
Expand Down Expand Up @@ -957,3 +958,32 @@ end
back(Δ::AbstractArray) = (nothing, getindex.(_back.(Δ), 1))
return Fill(y, size(r)), back
end

# Sparse Arrays

@adjoint function SparseMatrixCSC{T,N}(arr) where {T,N}
SparseMatrixCSC{T,N}(arr), Δ -> (collect(Δ),)
end

@adjoint function SparseVector{T,N}(v) where {T,N}
SparseVector{T,N}(v), Δ -> (collect(Δ),)
end

@adjoint diagm(x::AbstractSparseArray) = diagm(x), Δ -> (diag(Δ), )

@adjoint function Broadcast.broadcasted(::Type{Float32}, a::AbstractSparseArray{T,N}) where {T,N}
Float32.(a), Δ -> (nothing, T.(Δ), )
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved
end

@adjoint Matrix(a::AbstractSparseArray) = Matrix(a), Δ -> (Δ,)

@adjoint function SparseArrays.spdiagm(x::Pair...)
ks = first.(x)
SparseArrays.spdiagm(x...), Δ -> begin
tuple((k => diag(Δ, k) for k in ks)...)
end
end

@adjoint Pair(x,y) = Pair(x,y), Δ -> (nothing, Δ.second)

@nograd issymmetric
8 changes: 8 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Zygote, Test, Random, LinearAlgebra, Statistics, FillArrays,
AbstractFFTs, FFTW, Distances
using Zygote: gradient
using Base.Broadcast: broadcast_shape
using Distributed: pmap
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this import?

using SparseArrays
using Distributed: pmap, CachingPool, workers
import FiniteDifferences

Expand Down Expand Up @@ -1571,6 +1573,12 @@ end
end
end

@testset "Sparse" begin
@test gradtest(x -> sum(sparse(x)), rand(Float32, 3,3))
@test gradtest(x -> sum(sparse(x)), rand(Float32, 3)) # test vectors also
@test gradcheck(x -> sum(diagm(x)), sparse(rand(3)))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can add a test for broadcasting


@testset "broadcasted($op, Array, Bool)" for op in (+,-,*)
@testset "with $bool and sizes $s" for s in ((4,), (2,3)), bool in (true,false)
r = rand(Int8, s) .+ 0.0
Expand Down