Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some Array rules #491

Merged
merged 24 commits into from
Aug 2, 2021
Merged

Some Array rules #491

merged 24 commits into from
Aug 2, 2021

Conversation

willtebbutt
Copy link
Member

No description provided.

@willtebbutt willtebbutt changed the title Some Array rules WIP: Some Array rules Aug 1, 2021
@codecov-commenter
Copy link

codecov-commenter commented Aug 1, 2021

Codecov Report

Merging #491 (000b33f) into master (36508af) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #491      +/-   ##
==========================================
+ Coverage   98.04%   98.06%   +0.01%     
==========================================
  Files          22       22              
  Lines        2306     2321      +15     
==========================================
+ Hits         2261     2276      +15     
  Misses         45       45              
Impacted Files Coverage Δ
src/rulesets/Base/array.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 36508af...000b33f. Read the comment docs.

@willtebbutt willtebbutt mentioned this pull request Aug 1, 2021
@willtebbutt
Copy link
Member Author

willtebbutt commented Aug 1, 2021

@mcabbott @nickrobinson251 I've constrained the vect implementation a bit, because I couldn't figure out how to implement the most generic (and most useful) pullback. I've opened #492 to discuss further.

edit: (@nickrobinson251 I'll wait for you to approve before I merge, because I've modified thevect implementation a bit since you last looked at it)

src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
test/rulesets/Base/array.jl Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
willtebbutt and others added 3 commits August 1, 2021 22:43
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
willtebbutt and others added 5 commits August 2, 2021 13:13
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@willtebbutt
Copy link
Member Author

Seems like all comments have now been resolved -- @mcabbott I've requested a re-review because a lot has changed since your original approval.

) where {N}
projects = map(ProjectTo, X)
function vect_pullback(ȳ)
X̄ = ntuple(n -> projects[n](ȳ[n]), N)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering whether projects[n] gets handled well. @code_warntype seems happy... and my attempts to make things easier to unroll all make it slower:

julia> @btime rrule(Base.vect, 1,2,3)[2]($(rand(3)))
  28.684 ns (1 allocation: 112 bytes)
(NoTangent(), 0.7437540971290453, 0.6835525631785602, 0.29678387383966687)

julia> @btime rrule(Base.vect, 1+im,2+im,3+im)[2]($(rand(3)))
  235.140 ns (6 allocations: 416 bytes)
(NoTangent(), 0.9212083670665245 + 0.0im, 0.989459216141123 + 0.0im, 0.8454719840778347 + 0.0im)

julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
  609.760 ns (6 allocations: 320 bytes)
(NoTangent(), 0.2914057312235363, 0.23309219863512798 + 0.0im, 0.08023319383991401)

julia> struct StaticGetter{i} end; @inline (::StaticGetter{i})(v) where {i} = v[i]; # from Zygote

julia> function rrule(
            ::typeof(Base.vect),
            X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
        ) where {N}
            valN = Val(N)
            projects = map(ProjectTo, X)
            function vect_pullback(ȳ)
                X̄ = ntuple(n -> StaticGetter{n}()(projects)(ȳ[n]), valN)
                return (NoTangent(), X̄...)
            end
            return Base.vect(X...), vect_pullback
        end
rrule (generic function with 723 methods)

julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
  1.442 μs (11 allocations: 448 bytes)
(NoTangent(), 0.7692288886268137, 0.39993377443044065 + 0.0im, 0.6341039234276757)

@willtebbutt
Copy link
Member Author

Added one additional test -- will merge once CI passes.

@willtebbutt willtebbutt changed the title WIP: Some Array rules Some Array rules Aug 2, 2021
@willtebbutt willtebbutt merged commit 7593339 into master Aug 2, 2021
@willtebbutt willtebbutt deleted the wct/some-rules branch August 2, 2021 15:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants