Skip to content

Commit

Permalink
Merge branch 'main' into kf/forwardchunk
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith authored Sep 29, 2022
2 parents 183d0be + 55d2871 commit 45f04ae
Show file tree
Hide file tree
Showing 13 changed files with 252 additions and 158 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
name: CI
on:
schedule:
- cron: '0 6 * * *' # Daily at 6 AM UTC (2 AM EST)
pull_request:
push:
branches:
Expand All @@ -14,11 +16,16 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7' # Lowest claimed support in Project.toml
# - '1' # Latest Release # Testing on 1.8 gives this message:
# ┌ Warning: ir verification broken. Either use 1.9 or 1.7
# └ @ Diffractor ~/work/Diffractor.jl/Diffractor.jl/src/stage1/recurse.jl:889
- 'nightly'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
# FIXME
# - windows-latest
arch:
- x64
steps:
Expand Down
14 changes: 2 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
ChainRules = "1.5"
ChainRulesCore = "1.2"
ChainRules = "1.44.6"
ChainRulesCore = "1.15.3"
Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
StructArrays = "0.6"
julia = "1.7"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ForwardDiff", "LinearAlgebra", "Random", "Symbolics"]
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

**Docs:**
[![](https://img.shields.io/badge/docs-master-blue.svg)](https://juliadiff.org/Diffractor.jl/dev)
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/Diffractor.jl/stable)

# General Overview

Expand Down
2 changes: 1 addition & 1 deletion docs/src/reading_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ many of these references are quite dense and though I've found small nuggets
of insight in each, excavating those took many hours. Also, these are not
introductory texts. If you've not taken an introductory differential
geometry course, I would recommend looking for that first. Don't feel bad if
some of these references read like like nonsense. It often reads that way to me to.
some of these references read like like nonsense. It often reads that way to me too.

# Reading on Optics

Expand Down
12 changes: 6 additions & 6 deletions docs/terminology/terminology.tex
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ \section{Optical Constructions}
on the representative - see Riley for details).
\end{definition}

This definition makes maniffest the combination of co- and contravariant data.
This definition makes manifest the combination of co- and contravariant data.
For a representative $\langle l | r \rangle$, $l$ varies covariantly while $r$
varies contravariantly. We additionally have a ``memory" or ``residual" object $M$.
This object is not uniquely determined and in fact we shall make good use of that
Expand Down Expand Up @@ -564,13 +564,13 @@ \subsubsection{Coproduct Structure}
Given our utter disappointment with the product structure, can we have any
hope to lift the co-product structure. Yes, we do! First we construct the
co-product itself. For two optics $\langle l_1 | r_1 \rangle: (A, A') \to (B, B')$
with residual $M_1$ and $\langle l_2 | r_2 \rangle: (C, D') \to (D, D')$ with residual $M_2$,
with residual $M_1$ and $\langle l_2 | r_2 \rangle: (C, C') \to (D, D')$ with residual $M_2$,
we construct a new optic $\langle l_{12} | r_{12} \rangle$ where

\begin{equation}
\begin{split}
l_{12} = (l_1 \oplus l_2) \bbsemi \leftrightarrow_{oplus} \\
r_{12} = \leftrightarrow_{oplus}^{-1} \bbsemi (r_1 \oplus r_2)
l_{12} = (l_1 \oplus l_2) \bbsemi \leftrightarrow_{\oplus} \\
r_{12} = \leftrightarrow_{\oplus}^{-1} \bbsemi (r_1 \oplus r_2)
\end{split}
\end{equation}

Expand Down Expand Up @@ -881,9 +881,9 @@ \subsubsection{Copy}
\end{snippet}

However, note that while this is a valid definition under our definition of
an optic functor, applying $textbf{\euro{}}$ now leads to accumulation order
an optic functor, applying $\textbf{\euro{}}$ now leads to accumulation order
dependence (the same happens in the variant where cloning is done once per value).
As a result, $textbf{\euro{}}$ would no longer preserve standard SSA invariants.
As a result, $\textbf{\euro{}}$ would no longer preserve standard SSA invariants.
This is legal according to our definition, but it can be convenient to be able to
arbitrarily permute SSA transforms and optic functors. Thus, we would generally
only ever choose one of the first two definitions.
Expand Down
75 changes: 34 additions & 41 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ function (g::∇getindex)(Δ)
(ChainRulesCore.NoTangent(), Δ′, map(_ -> nothing, g.i)...)
end

function ChainRulesCore.rrule(g::∇getindex, Δ)
function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
end

function ChainRulesCore.rrule(::typeof(getindex), xs::Array, i...)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...)
xs[i...], ∇getindex(xs, i)
end

Expand All @@ -37,14 +37,14 @@ function assert_gf(f)
@assert sizeof(sin) == 0
end

function ChainRulesCore.rrule(::typeof(assert_gf), f)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(assert_gf), f)
assert_gf(f), Δ->begin
(NoTangent(), NoTangent())
end
end

#=
function ChainRulesCore.rrule(::typeof(map), f, xs::Vector...)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(map), f, xs::Vector...)
assert_gf(f)
primal, dual = reversediff_array(f, xs...)
primal, Δ->begin
Expand Down Expand Up @@ -94,7 +94,7 @@ function ChainRulesCore.frule((_, ∂A, ∂B), ::typeof(*), A::AbstractMatrix{<:
end

#=
function ChainRulesCore.rrule(::typeof(map), f, xs::Vector)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(map), f, xs::Vector)
assert_gf(f)
arrs = reversediff_array(f, xs)
primal = getfield(arrs, 1)
Expand All @@ -105,7 +105,7 @@ end
=#

#=
function ChainRulesCore.rrule(::typeof(map), f, xs::Vector, ys::Vector)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(map), f, xs::Vector, ys::Vector)
assert_gf(f)
arrs = reversediff_array(f, xs, ys)
primal = getfield(arrs, 1)
Expand All @@ -116,14 +116,14 @@ end
=#

xsum(x::Vector) = sum(x)
function ChainRulesCore.rrule(::typeof(xsum), x::Vector)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(xsum), x::Vector)
xsum(x), let xdims=size(x)
Δ->(NoTangent(), xfill(Δ, xdims...))
end
end

xfill(x, dims...) = fill(x, dims...)
function ChainRulesCore.rrule(::typeof(xfill), x, dim)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(xfill), x, dim)
xfill(x, dim), Δ->(NoTangent(), xsum(Δ), NoTangent())
end

Expand All @@ -137,11 +137,11 @@ struct NonDiffOdd{N, O, P}; end
# This should not happen
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()

@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
@Base.pure function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
end

function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.tuple), args...)
Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...)
end

Expand All @@ -150,12 +150,6 @@ end

ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()

# Skip AD'ing through the axis computation
function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
return Base.Broadcast.instantiate(bc), Δ->begin
Core.tuple(NoTangent(), Δ)
end
end


using StaticArrays
Expand All @@ -169,11 +163,11 @@ struct to_tuple{N}; end
end
(::to_tuple)(Δ::SArray) = getfield(Δ, :data)

function ChainRules.rrule(::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
SArray{S, T, N, L}(x), to_tuple{L}()
end

function ChainRules.rrule(::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
SArray{S, T, N, L}(x), to_tuple{L}()
end

Expand All @@ -187,26 +181,22 @@ end

@ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T)

function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
getindex(A, args...), getindex(∂A, args...)
end

function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
end

function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::AbstractVector)
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractVector, B::AbstractVector)
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
end

function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
function ChainRules.rrule(::DiffractorRuleConfig, AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
# We're leaving these in the eltype that the cotangent vector already has.
# There isn't really a good reason to believe we should convert to the
# original array type, so don't unless explicitly requested.
AT(x), Δ->(NoTangent(), Δ)
end

function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
function ChainRules.rrule(::DiffractorRuleConfig, AT::Type{<:Array}, undef::UndefInitializer, args...)
# We're leaving these in the eltype that the cotangent vector already has.
# There isn't really a good reason to believe we should convert to the
# original array type, so don't unless explicitly requested.
Expand All @@ -217,38 +207,39 @@ function unzip_tuple(t::Tuple)
map(x->x[1], t), map(x->x[2], t)
end

function ChainRules.rrule(::typeof(unzip_tuple), args::Tuple)
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(unzip_tuple), args::Tuple)
unzip_tuple(args), Δ->(NoTangent(), map((x,y)->(x,y), Δ...))
end

struct BackMap{T}
f::T
end
(f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...)
back_apply(x, y) = x(y)
back_apply_zero(x) = x(Zero())
back_apply(x, y) = x(y) # this is just |> with arguments reversed
back_apply_zero(x) = x(Zero()) # Zero is not defined

function ChainRules.rrule(::typeof(map), f, args::Tuple)
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
a, b = unzip_tuple(map(BackMap(f), args))
function back(Δ)
function map_back(Δ)
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
(NoTangent(), sum(fs), xs)
end
function back::ZeroTangent)
(fs, xs) = unzip_tuple(map(back_apply_zero, b))
(NoTangent(), sum(fs), xs)
end
a, back
map_back::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
a, map_back
end

function ChainRules.rrule(::typeof(Base.ntuple), f, n)
ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent())

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
a, b = unzip_tuple(ntuple(BackMap(f), n))
a, function (Δ)
function ntuple_back(Δ)
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
end
ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
a, ntuple_back
end

function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
Vector{T}(undef, dims...), zeros(T, dims...)
end

Expand All @@ -258,11 +249,13 @@ end
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()

# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
function ChainRulesCore.rrule(::Type{Thunk}, thnk)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{Thunk}, thnk)
z, ∂z = ∂⃖¹(thnk)
z, Δ->(NoTangent(), ∂z(Δ)...)
end

function ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, add!!, val)
val, Δ->(NoTangent(), NoTangent(), Δ)
end

Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
6 changes: 4 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ However, users may provide additional overloads for custom representations of
one dimensional Riemannian manifolds.
"""
dx(x::Real) = one(x)
dx(::NoTangent) = NoTangent()
dx(::ZeroTangent) = ZeroTangent()
dx(x::Complex) = error("Tried to take the gradient of a complex-valued function.")
dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued function.")

Expand Down Expand Up @@ -125,7 +127,7 @@ end
# N.B: This means the gradient is not available for zero-arg function, but such
# a gradient would be guaranteed to be `()`, which is a bit of a useless thing
function (::Type{∇})(f, x1, args...)
(f)(x1, args...)
unthunk.((f)(x1, args...))
end

const gradient =
Expand Down Expand Up @@ -157,7 +159,7 @@ function (f::PrimeDerivativeBack)(x)
z = ∂⃖¹(lower_pd(f), x)
y = getfield(z, 1)
f☆ = getfield(z, 2)
return getfield(f☆(dx(y)), 2)
return unthunk(getfield(f☆(dx(y)), 2))
end

# Forwards primal derivative
Expand Down
21 changes: 18 additions & 3 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,26 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
fnames = union(fieldnames(x), fieldnames(y))
isempty(fnames) && return :((;)) # code below makes () instead
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
end
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
@Base.constprop :aggressive accum(a::NoTangent, b) = b
@Base.constprop :aggressive accum(a, b::NoTangent) = a
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
@Base.constprop :aggressive accum(a::AbstractZero, b) = b
@Base.constprop :aggressive accum(a, b::AbstractZero) = a
@Base.constprop :aggressive accum(a::AbstractZero, b::AbstractZero) = NoTangent()

using ChainRulesCore: Tangent, backing

function accum(x::Tangent{T}, y::NamedTuple) where T
# @warn "gradient is both a Tangent and a NamedTuple" x y
_tangent(T, accum(backing(x), y))
end
accum(x::NamedTuple, y::Tangent) = accum(y, x)
# This solves an ambiguity, but also avoids Tangent{ZeroTangent}() which + does not:
accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing(y)))

_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()
Loading

0 comments on commit 45f04ae

Please sign in to comment.