Adds some missing unthunks #282

Merged
merged 6 commits into from
Oct 14, 2020

Conversation

oxinabox
Member

I am not 100% sure how I feel about this.
Maybe we should actually overload some more linear operators in ChainRulesCore.
Including in particular: the Diagonal constructor,
and \ with a Vector

But for svd_rev, because we use some of the inputs multiple times, it is actually faster to unthunk them once at the start anyway.
I feel like we shouldn't actually have to use any input more than once, but that is another issue.
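
To illustrate the cost argument, here is a minimal sketch using a toy thunk type. `ToyThunk`, `toy_unthunk`, `uses_lazily` and `uses_eagerly` are hypothetical names made up for this example, not ChainRulesCore's `Thunk` or `unthunk`; the sketch only assumes that forcing a thunk reruns its deferred computation each time.

```julia
# Hypothetical stand-in for a lazily evaluated sensitivity.
# This is NOT ChainRulesCore's Thunk; it only mimics "compute on demand, no caching".
struct ToyThunk{F}
    f::F
end

toy_unthunk(t::ToyThunk) = t.f()  # reruns the deferred work on every call
toy_unthunk(x) = x                # anything else passes through unchanged

const CALLS = Ref(0)  # counts how often the deferred work actually runs

make_thunk() = ToyThunk() do
    CALLS[] += 1
    [1.0, 2.0, 3.0]
end

# Forces the input at each of its three uses.
uses_lazily(s) = toy_unthunk(s) .+ 2 .* toy_unthunk(s) .- toy_unthunk(s)

# Forces the input once at the top and reuses the value.
function uses_eagerly(s)
    v = toy_unthunk(s)
    return v .+ 2 .* v .- v
end

CALLS[] = 0
lazy_result = uses_lazily(make_thunk())
lazy_calls = CALLS[]

CALLS[] = 0
eager_result = uses_eagerly(make_thunk())
eager_calls = CALLS[]
```

Both functions return the same value, but under this assumption the first runs the deferred work once per use and the second runs it once in total.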

The testset in Nabla that revealed this was:

    @testset "Tape updating from multiple components" begin
        ∇f = ∇() do X
            U, S, V = svd(X)
            Y = U * Diagonal(S) * V'
            sum(Y)
        end
        X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
        @test ∇f(X)[1] ≈ ones(3, 2) atol=1e-5
    end

I am not sure what is a good way to test this PR.
Do we want to add something to ChainRulesTestUtils that reruns all the tests with Thunks of the inputs passed in as well?
That wouldn't trivially catch Composites.

This is needed for invenia/Nabla.jl#189

Review thread on src/rulesets/LinearAlgebra/factorization.jl (outdated, resolved)
Review thread on src/rulesets/LinearAlgebra/structured.jl (outdated, resolved)
@mzgubic
Member

mzgubic commented Oct 13, 2020

What actually happened in Nabla with that testset, i.e. why did it not pass? I thought the only difference between the current implementation and the new one is that the unthunk(x) happens only once and not many times, and the value is reused. Or are there some side-effects happening under the hood in unthunk?

@oxinabox
Member Author

oxinabox commented Oct 13, 2020

> What actually happened in Nabla with that testset, i.e. why did it not pass?

It throws a MethodError,
first for the Diagonal constructor, and then, once that is fixed, within svd_rev for \

@mzgubic
Member

mzgubic commented Oct 13, 2020

Aha thanks, I think I now see what the comment on testing means. In my own words:

If someone defines their own rrule, how do we provide a tool that tests whether it also works when thunks are passed as sensitivities?

An rrule which returns function Diagonal_pullback(x) would work if a thunk is passed (but may not be the most efficient). If for some reason a user defines a more restrictive argument type, say x::AbstractMatrix as above, we want them to know it will not work with thunks.
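
A minimal sketch of that difference, using a toy thunk type rather than ChainRulesCore itself. `ToyThunk`, `toy_unthunk`, `strict_pullback` and `loose_pullback` are hypothetical names for illustration only, and the permissive version here forces its input explicitly, which is one possible way to make it work:

```julia
# Hypothetical stand-in for a lazily evaluated sensitivity (not ChainRulesCore's Thunk).
struct ToyThunk{F}
    f::F
end

toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(x) = x

# Restrictive signature: only dispatches on an actual matrix.
strict_pullback(dy::AbstractMatrix) = 2 .* dy

# Permissive signature: accepts anything and forces it on entry.
loose_pullback(dy) = 2 .* toy_unthunk(dy)

lazy = ToyThunk(() -> [1.0 2.0; 3.0 4.0])

# Dispatch fails before the body of strict_pullback ever runs.
strict_failed = try
    strict_pullback(lazy)
    false
catch err
    err isa MethodError
end

loose_value = loose_pullback(lazy)
```

A test utility that reran each rule with its sensitivities wrapped this way would surface the `strict_pullback` case as a MethodError.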

The problem with Composites is that they are not themselves thunks, but they may contain thunks, so the simple test of thunking the inputs will not catch those.

Are such Composites a common case, however? I.e. the rrule would have to return a pullback where Some_pullback(x::Composite{SpecificType}) while what would be passed is Composite{Thunk}, if I understand the issue correctly?

@oxinabox
Member Author

> Are such Composites a common case, however? I.e. the rrule would have to return a pullback where Some_pullback(x::Composite{SpecificType}) while what would be passed is Composite{Thunk}, if I understand the issue correctly?

It isn't Composite{Thunk}; it is a Composite{SpecificType} whose field X isa Thunk,
and it is very common.

@mzgubic
Member

mzgubic commented Oct 13, 2020

> It isn't Composite{Thunk} it is Composite{SpecificType}.X isa Thunk
> and it is very common.

How does this break things though? (Not by MethodError, right?)

@oxinabox
Member Author

By MethodError

> > It isn't Composite{Thunk} it is Composite{SpecificType}.X isa Thunk
> > and it is very common.
>
> How does this break things though? (Not by MethodError, right?)

By a MethodError when you try to do Diagonal(comp.F)
or comp.F / s
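
A rough sketch of that failure mode, with a NamedTuple standing in for a Composite and a toy thunk type standing in for a Thunk. `ToyThunk` and `toy_unthunk` are hypothetical names for this example, and it assumes that unthunking at the top level does not recurse into fields:

```julia
using LinearAlgebra

# Hypothetical stand-in for a lazily evaluated sensitivity (not ChainRulesCore's Thunk).
struct ToyThunk{F}
    f::F
end

toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(x) = x  # non-thunks, including containers, pass through untouched

# NamedTuple standing in for a Composite: not a thunk itself, but holding one.
comp = (F = ToyThunk(() -> [1.0, 2.0]),)

# Unthunking the container leaves the field lazy.
still_lazy = toy_unthunk(comp).F isa ToyThunk

# No Diagonal constructor accepts the toy thunk type.
diag_failed = try
    Diagonal(comp.F)
    false
catch err
    err isa MethodError
end

# No division method accepts it either.
div_failed = try
    comp.F / 2.0
    false
catch err
    err isa MethodError
end

# Forcing the field first makes both operations work.
D = Diagonal(toy_unthunk(comp.F))
halved = toy_unthunk(comp.F) / 2.0
```

This is also why wrapping only the top-level inputs in thunks during testing would not exercise this case: the container passes every `isa` check, and the error only appears once a field is used.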

Review thread on test/rulesets/LinearAlgebra/factorization.jl (outdated, resolved)
Review thread on test/rulesets/LinearAlgebra/factorization.jl (resolved)
Review thread on src/rulesets/LinearAlgebra/factorization.jl (resolved)
@oxinabox
Member Author

The unthunk in Diagonal was only needed because of a mistake in Nabla.
Nabla was supposed to be unthunking things before passing them into propagators, but it wasn't.
So we can remove the Diagonal one.
That probably doesn't help with the Composite one, since Nabla shouldn't be recursively unthunking.
JuliaDiff/ChainRulesCore.jl#121 would solve it.
But in this case we still want to unthunk in advance, since things are used multiple times in svd_rev.
They probably shouldn't be, but they are.

oxinabox and others added 3 commits October 14, 2020 15:59
Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
@oxinabox oxinabox merged commit bbe68cc into master Oct 14, 2020
@oxinabox oxinabox deleted the ox/unthunkgood branch October 16, 2020 14:47