Skip to content

Commit 4ee4ef5

Browse files
authored
@non_differentiable _denom (#687)
* non_differentiable _denom * Update Project.toml * ∇getindex(::AbstractZero) paths * non_differentiable foreach(f, ()) * also fix summary
1 parent 9a405f7 commit 4ee4ef5

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.46.0"
3+
version = "1.46.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/indexing.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,9 @@ function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...)
6161
end
6262

6363
function rrule(::typeof(getindex), x::AbstractArray, inds...)
64-
function getindex_pullback(dy)
65-
nots = map(Returns(NoTangent()), inds)
66-
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
67-
end
64+
nots = map(Returns(NoTangent()), inds)
65+
getindex_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
66+
getindex_pullback(z::AbstractZero) = (NoTangent(), z, nots...)
6867
return x[inds...], getindex_pullback
6968
end
7069

@@ -90,6 +89,7 @@ function ∇getindex(x::AbstractArray, dy, inds...)
9089
∇getindex!(dx, dy, plain_inds...)
9190
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
9291
end
92+
∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z
9393

9494
"""
9595
_setindex_zero(x, dy, inds...)
@@ -191,10 +191,9 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
191191
end
192192

193193
function rrule(::typeof(view), x::AbstractArray, inds...)
194-
function view_pullback(dy)
195-
nots = map(Returns(NoTangent()), inds)
196-
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
197-
end
194+
nots = map(Returns(NoTangent()), inds)
195+
view_pullback(dy) = (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
196+
view_pullback(z::AbstractZero) = (NoTangent(), z, nots...)
198197
return view(x, inds...), view_pullback
199198
end
200199

src/rulesets/Base/nondiff.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@
189189
@non_differentiable floatmax(::Any)
190190
@non_differentiable floatmin(::Any)
191191
@non_differentiable flush(::Any)
192+
@non_differentiable foreach(::Any, ::Tuple{})
192193

193194
@non_differentiable gensym(::Symbol)
194195
@non_differentiable gensym(::String...)
@@ -422,6 +423,7 @@ end
422423
@non_differentiable supertype(::Any)
423424
@non_differentiable Symbol(::Any...)
424425
@non_differentiable symlink(::AbstractString, ::AbstractString)
426+
@non_differentiable summary(::Any)
425427

426428
@non_differentiable take!(::Base.GenericIOBuffer)
427429
@non_differentiable take!(::IOStream)
@@ -472,6 +474,7 @@ elseif isdefined(Base, :cumulative_compile_time_ns)
472474
end
473475
@non_differentiable Base.time_print(::Any...)
474476
@non_differentiable Base.OneTo(::Any...)
477+
@non_differentiable Base.array_summary(::Any)
475478

476479
@non_differentiable Broadcast.combine_styles(::Any...)
477480
@non_differentiable Broadcast.result_style(::Any)

src/rulesets/Statistics/statistics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ _denom(x, dims) = size(x, dims)
66
_denom(x, dims::Colon) = length(x)
77
_denom(x, dims::Union{Tuple, AbstractArray}) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
88

9+
@non_differentiable _denom(::Any, ::Any) # else Zygote tries to AD unique(::Tuple)
10+
911
function rrule(::typeof(mean), x::AbstractArray{<:Union{Real,Complex,AbstractArray}}; dims=:)
1012
y_sum, sum_pullback = rrule(sum, x; dims)
1113
n = _denom(x, dims)

0 commit comments

Comments
 (0)