Remove frule for getindex(::Tuple, i) (#680)
* Remove frule for getindex(::Tuple, i)

Having this chain rule is sub-optimal, because it prevents early-SROA
in Diffractor-like systems that would like to perform some optimizations
before applying AD (but can't do any optimization on functions that
have custom rules). By letting it go down to the `getfield`, regular
SROA can apply. Any AD system should handle `getfield` anyway, so
I don't think there's a strong reason to have this.

Similar reasoning applies to the reverse rules also, but they
aren't currently actively causing me problems, so this PR only
removes the frule, since I don't think many other packages are
using them. We can revisit the rrules later.

* Also remove the rules for first/tail

For similar reasons as getindex, having a rule for first/tail is
suboptimal because it suppresses early SROA. Tail is particularly
problematic, because it is used in the implementation of the
```
x, y... = abc
```
syntax, which users expect to be eliminated early.

* add getfield rule and remove tests for deleted rules
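
As a rough illustration of the reasoning above (plain Base Julia, no AD package loaded; `toy_getfield_frule` is a hypothetical name and not part of ChainRules): integer indexing and slurp destructuring of a tuple give the same values as `getfield`, `first`, and `Base.tail`, so in this sketch a single forward rule that projects both the primal and the tangent with `getfield` covers them.

```julia
# Plain Base Julia; slurp destructuring on the left-hand side is
# assumed to be available (Julia 1.6 or newer).
t = (1.2, 3.4, 5.6)

# Integer indexing of a tuple returns the same value as the getfield builtin.
@assert t[2] === getfield(t, 2)

# Destructuring with a slurp yields the same pieces as first and Base.tail.
a, others... = t
@assert a === first(t)
@assert others === Base.tail(t)

# Hypothetical forward-mode rule for getfield on a tuple: apply the same
# positional projection to the primal x and to its tangent x_dot.
toy_getfield_frule(x::Tuple, x_dot::Tuple, i::Int) = (getfield(x, i), getfield(x_dot, i))

t_dot = (1.0, 0.0, 0.0)
y, y_dot = toy_getfield_frule(t, t_dot, 1)
```

The sketch uses plain tuples for tangents to stay self-contained; an actual rule would use whatever tangent type the AD system provides.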

---------

Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
Keno and oscardssmith authored Jul 4, 2023
1 parent 859f6ab commit 05ebb38
Showing 2 changed files with 6 additions and 57 deletions.
39 changes: 4 additions & 35 deletions src/rulesets/Base/indexing.jl
@@ -1,14 +1,6 @@
-#####
-##### getindex(::Tuple)
-#####
-
-function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer)
-    return x[i], ẋ[i]
-end
-
-function frule((_, ẋ), ::typeof(getindex), x::Tuple, i)
-    y = x[i]
-    return y, Tangent{typeof(y)}(ẋ[i]...)
+# Int rather than Int64/Integer is intentional
+function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
+    return x.i, ẋ.i
 end
 
 "for a given tuple type, returns a Val{N} where N is the length of the tuple"
@@ -77,7 +69,7 @@ end
 """
     ∇getindex(x, dy, inds...)
 
-For the `rrule` of `y = x[inds...]`, this function is roughly
+For the `rrule` of `y = x[inds...]`, this function is roughly
 `setindex(zero(x), dy, inds...)`, returning the array `dx`.
 Differentiable. Includes `ProjectTo(x)(dx)`.
 """
@@ -191,29 +183,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
     return dx
 end
 
-#####
-##### first, tail
-#####
-
-function frule((_, ẋ), ::typeof(first), x::Tuple)
-    return first(x), first(ẋ)
-end
-
-function rrule(::typeof(first), x::T) where {T<:Tuple}
-    first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...))
-    return first(x), first_back
-end
-
-function frule((_, ẋ), ::typeof(Base.tail), x::Tuple)
-    y = Base.tail(x)
-    return y, Tangent{typeof(y)}(Base.tail(ẋ)...)
-end
-
-function rrule(::typeof(Base.tail), x::T) where {T<:Tuple}
-    tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...))
-    return Base.tail(x), tail_pullback
-end
-
 #####
 ##### view
 #####
24 changes: 2 additions & 22 deletions test/rulesets/Base/indexing.jl
@@ -3,12 +3,7 @@
 x = (1.2, 3.4, 5.6)
 x2 = (rand(2), (a=1.0, b=x))
 
-# Forward
-test_frule(getindex, x, 2)
-test_frule(getindex, x2, 1)
-test_frule(getindex, x, 1:2)
-test_frule(getindex, x2, :)
-
+# don't test Forward because this will be handled by lowering to getfield
 # Reverse
 test_rrule(getindex, x, 2)
 @test_skip test_rrule(getindex, x2, 1, check_inferred=false) # method ambiguity, maybe fixed by https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/253
@@ -168,22 +163,7 @@
 end
 end
 
-@testset "first & tail" begin
-    x = (1.2, 3.4, 5.6)
-    x2 = (rand(2), (a=1.0, b=x))
-
-    test_frule(first, x)
-    test_frule(first, x2)
-
-    test_rrule(first, x)
-    # test_rrule(first, x2) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::NoTangent, ::Tangent{NamedTuple{(:a, :b), Tuple{Float64, Tuple{Float64, Float64, Float64}}}, NamedTuple{(:a, :b), Tuple{Float64, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}}, ::String) is ambiguous
-
-    test_frule(Base.tail, x, check_inferred=false) # return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}} does not match inferred return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}}}
-    test_frule(Base.tail, x2, check_inferred=false)
-
-    test_rrule(Base.tail, x)
-    test_rrule(Base.tail, x2)
-end
+# first & tail handled by getfield rules
 
 @testset "view" begin
     test_frule(view, rand(3, 4), :, 1)