-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Some Array rules #491
Some Array rules #491
Conversation
Codecov Report
@@ Coverage Diff @@
## master #491 +/- ##
==========================================
+ Coverage 98.04% 98.06% +0.01%
==========================================
Files 22 22
Lines 2306 2321 +15
==========================================
+ Hits 2261 2276 +15
Misses 45 45
Continue to review full report at Codecov.
|
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@mcabbott @nickrobinson251 I've constrained the edit: (@nickrobinson251 I'll wait for you to approve before I merge, because I've modified the |
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Seems like all comments have now been resolved -- @mcabbott I've requested a re-review because a lot has changed since your original approval. |
) where {N} | ||
projects = map(ProjectTo, X) | ||
function vect_pullback(ȳ) | ||
X̄ = ntuple(n -> projects[n](ȳ[n]), N) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering whether projects[n]
gets handled well. @code_warntype
seems happy... and my attempts to make things easier to unroll all make it slower:
julia> @btime rrule(Base.vect, 1,2,3)[2]($(rand(3)))
28.684 ns (1 allocation: 112 bytes)
(NoTangent(), 0.7437540971290453, 0.6835525631785602, 0.29678387383966687)
julia> @btime rrule(Base.vect, 1+im,2+im,3+im)[2]($(rand(3)))
235.140 ns (6 allocations: 416 bytes)
(NoTangent(), 0.9212083670665245 + 0.0im, 0.989459216141123 + 0.0im, 0.8454719840778347 + 0.0im)
julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
609.760 ns (6 allocations: 320 bytes)
(NoTangent(), 0.2914057312235363, 0.23309219863512798 + 0.0im, 0.08023319383991401)
julia> struct StaticGetter{i} end; @inline (::StaticGetter{i})(v) where {i} = v[i]; # from Zygote
julia> function rrule(
::typeof(Base.vect),
X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
) where {N}
valN = Val(N)
projects = map(ProjectTo, X)
function vect_pullback(ȳ)
X̄ = ntuple(n -> StaticGetter{n}()(projects)(ȳ[n]), valN)
return (NoTangent(), X̄...)
end
return Base.vect(X...), vect_pullback
end
rrule (generic function with 723 methods)
julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
1.442 μs (11 allocations: 448 bytes)
(NoTangent(), 0.7692288886268137, 0.39993377443044065 + 0.0im, 0.6341039234276757)
Added one additional test -- will merge once CI passes. |
No description provided.