Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sparse rrule #579

Merged
merged 7 commits into from
Jan 31, 2022
Merged

add sparse rrule #579

merged 7 commits into from
Jan 31, 2022

Conversation

CarloLucibello
Copy link
Contributor

@CarloLucibello CarloLucibello commented Jan 22, 2022

Partial replacement for #246

cc @sethaxen

@CarloLucibello CarloLucibello changed the title add sparse(I, J, V, m, n, +) rrule add sparse rrule Jan 22, 2022
@CarloLucibello CarloLucibello changed the title add sparse rrule add sparse rrule Jan 22, 2022

function rrule(::typeof(sparse), A::Union{AbstractVector, AbstractMatrix})
function sparse_pullback(Ω̄)
return NoTangent(), Ω̄
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this ok or we need to project or something?

return sparse(I, J, V, m, n, combine), sparse_pullback
end

function rrule(::typeof(sparse), A::Union{AbstractVector, AbstractMatrix})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these rules be attached to sparse or one step later to things like Type{<:SparseMatrixCSC}? I don't have examples but catching all as many paths like T(A) and convert(T, A) etc. as possible sounds desirable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also need rules for SparseMatrixCSC, but for sure we need a rule for sparse since it calls sparse! for the heavy-duty.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I meant for the 1-arg sparse. The one which takes all the vectors is a different story.

Copy link
Contributor Author

@CarloLucibello CarloLucibello Jan 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry, didn't notice what you were pointing at. I'm not super confident in defining rrules for constructors.
Should it be rrule(::Type{<:SparseMatrixCSC},...) or rrule(::Type{<:SparseMatrixCSC{Tv,Ti},...) or the two are equivalent?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fiddle until it works, but my guess is that the first would capture all mode specific cases. Maybe not convert though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I defined the rrules for the type and removed this one. I did check separately that Zygote's gradients of sparse(A) and sparse(v) go through

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants