Skip to content

Commit

Permalink
Add randn_tangent to match zero_tangent fix for circular references (#228)
Browse files Browse the repository at this point in the history

* introduce `randn_tangent_internal`

* update tests

* some modification to `has_equal_data` still infinite recurse

* some code comments, not working yet

* some update, no test yet

* fix `has_equal_data`

* fix `increment!!`

* more resolutions

* more fix

* remove some changes to limit the scope

* bump patch version

* improve the comment

* remove one wrong test case

* remove modifications to `populate_address_map!`

* move some common circular ref test case to TestResources

* version bump

* Remove two-argument randn_tangent methods

* remove unreasonable test

* fix type stability issue

* simplify `zero_tangent` like `randn_tangent`

* Update test/tangents.jl

Co-authored-by: Will Tebbutt <wct23@cam.ac.uk>

* apply Will's suggestions

* fix test cov

* fix error in test

* add more test for coverage

---------

Co-authored-by: willtebbutt <wtebbutt@turing.ac.uk>
Co-authored-by: Will Tebbutt <wct23@cam.ac.uk>
  • Loading branch information
3 people authored Sep 17, 2024
1 parent 0390b34 commit eb4f54c
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 152 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.50"
version = "0.2.51"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
148 changes: 65 additions & 83 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,42 +428,16 @@ handles both circular references and aliasing correctly.
"""
zero_tangent(x)
function zero_tangent(x::P) where {P}
return isbitstype(P) ? zero_tangent_internal(x) : zero_tangent_internal(x, IdDict())
end

@inline zero_tangent_internal(::Union{Int8, Int16, Int32, Int64, Int128}) = NoTangent()
@inline zero_tangent_internal(x::IEEEFloat) = zero(x)
@inline function zero_tangent_internal(x::P) where {P<:Union{Tuple, NamedTuple}}
return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(zero_tangent_internal, x)
end
@generated function zero_tangent_internal(x::P) where P

tangent_type(P) == NoTangent && return NoTangent()

# This method can only handle struct types. Tell user to implement tangent type
# directly for primitive types.
isprimitivetype(P) && throw(error(
"$P is a primitive type. Implement a method of `zero_tangent` for it."
))

# Derive zero tangent. Tangent types of fields, and types of zeros need only agree
# if field types are concrete.
tangent_field_zeros_exprs = ntuple(fieldcount(P)) do n
if tangent_field_type(P, n) <: PossiblyUninitTangent
V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))}
return :(isdefined(x, $n) ? $V(zero_tangent_internal(getfield(x, $n))) : $V())
else
return :(zero_tangent_internal(getfield(x, $n)))
end
end
backing_data_expr = Expr(:call, :tuple, tangent_field_zeros_exprs...)
backing_expr = :($(backing_type(P))($backing_data_expr))
return :($(tangent_type(P))($backing_expr))
return zero_tangent_internal(x, isbitstype(P) ? nothing : IdDict())
end

# The `stackdict` name follows the convention used by Julia's `deepcopy` and `deepcopy_internal`:
# https://github.com/JuliaLang/julia/blob/48d4fd48430af58502699fdf3504b90589df3852/base/deepcopy.jl#L35
@inline zero_tangent_internal(x::Union{Int8,Int16,Int32,Int64,Int128,IEEEFloat}, stackdict::IdDict) = zero_tangent_internal(x)
@inline zero_tangent_internal(::Union{Int8, Int16, Int32, Int64, Int128}, ::Any) = NoTangent()
@inline zero_tangent_internal(x::IEEEFloat, ::Any) = zero(x)
@inline function zero_tangent_internal(x::P, stackdict::Any) where {P<:Union{Tuple, NamedTuple}}
return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(Base.Fix2(zero_tangent_internal, stackdict), x)
end
@inline function zero_tangent_internal(x::SimpleVector, stackdict::IdDict)
return map!(n -> zero_tangent_internal(x[n], stackdict), Vector{Any}(undef, length(x)), eachindex(x))
end
Expand All @@ -474,45 +448,42 @@ end
stackdict[x] = zt
return _map_if_assigned!(Base.Fix2(zero_tangent_internal, stackdict), zt, x)::Array{tangent_type(P), N}
end
@inline function zero_tangent_internal(x::P, stackdict::IdDict) where {P<:Union{Tuple, NamedTuple}}
return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(Base.Fix2(zero_tangent_internal, stackdict), x)
end
function zero_tangent_internal(x::P, stackdict::IdDict) where {P}

function zero_tangent_internal(x::P, stackdict) where {P}
tangent_type(P) == NoTangent && return NoTangent()

if tangent_type(P) <: MutableTangent
if !(stackdict isa IdDict)
throw(
ArgumentError(
"Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue."
)
)
end
if haskey(stackdict, x)
return stackdict[x]::tangent_type(P)
end
stackdict[x] = tangent_type(P)() # create a uninitialised MutableTangent
# If a circular reference exists, the recursive call below will first look `x` up in the
# stackdict and return this (still uninitialised) MutableTangent.
# Once the recursive call returns, the `fields` of that stored MutableTangent are populated.
stackdict[x].fields = backing_type(P)(zero_tangent_struct_field(x, stackdict))
stackdict[x].fields = zero_tangent_struct_field(x, stackdict)
return stackdict[x]::tangent_type(P)
else
if isbitstype(P)
return zero_tangent_internal(x)
else
return tangent_type(P)(backing_type(P)(zero_tangent_struct_field(x, stackdict)))
end
return tangent_type(P)(zero_tangent_struct_field(x, stackdict))
end
end

@inline function zero_tangent_struct_field(x::P, stackdict::IdDict) where {P}
return ntuple(fieldcount(P)) do n
@generated function zero_tangent_struct_field(x::P, stackdict) where {P}
tangent_field_zeros_exprs = ntuple(fieldcount(P)) do n
if tangent_field_type(P, n) <: PossiblyUninitTangent
V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))}
if isdefined(x, n)
return V(zero_tangent_internal(getfield(x, n), stackdict))
else
return V()
end
return :(isdefined(x, $n) ? $V(zero_tangent_internal(getfield(x, $n), stackdict)) : $V())
else
return zero_tangent_internal(getfield(x, n), stackdict)
return :(zero_tangent_internal(getfield(x, $n), stackdict))
end
end
tangent_fields_expr = Expr(:call, :tuple, tangent_field_zeros_exprs...)
return :($(backing_type(P))($tangent_fields_expr))
end

"""
Expand All @@ -529,47 +500,62 @@ details -- this docstring is intentionally non-specific in order to avoid becomi
Required for testing.
Generate a randomly-chosen tangent to `x`.
The design is closely modelled after `zero_tangent`.
"""
randn_tangent(::AbstractRNG, ::NoTangent) = NoTangent()
randn_tangent(rng::AbstractRNG, ::T) where {T<:IEEEFloat} = randn(rng, T)
function randn_tangent(rng::AbstractRNG, x::Array{T, N}) where {T, N}
dx = Array{tangent_type(T), N}(undef, size(x)...)
return _map_if_assigned!(Base.Fix1(randn_tangent, rng), dx, x)
function randn_tangent(rng::AbstractRNG, x::T) where {T}
return randn_tangent_internal(rng, x, isbitstype(T) ? nothing : IdDict())
end

randn_tangent_internal(::AbstractRNG, ::NoTangent, ::Any) = NoTangent()
randn_tangent_internal(rng::AbstractRNG, ::T, ::Any) where {T<:IEEEFloat} = randn(rng, T)
function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict::Any) where {P<:Union{Tuple, NamedTuple}}
return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(x -> randn_tangent_internal(rng, x, stackdict), x)
end
function randn_tangent(rng::AbstractRNG, x::SimpleVector)
function randn_tangent_internal(rng::AbstractRNG, x::SimpleVector, stackdict::IdDict)
return map!(Vector{Any}(undef, length(x)), eachindex(x)) do n
return randn_tangent(rng, x[n])
return randn_tangent_internal(rng, x[n], stackdict)
end
end
function randn_tangent(rng::AbstractRNG, x::P) where {P <: Union{Tuple, NamedTuple}}
tangent_type(P) == NoTangent && return NoTangent()
return tuple_map(x -> randn_tangent(rng, x), x)
end
function randn_tangent(rng::AbstractRNG, x::T) where {T<:Union{Tangent, MutableTangent}}
return T(randn_tangent(rng, x.fields))
function randn_tangent_internal(rng::AbstractRNG, x::Array{T, N}, stackdict::IdDict) where {T, N}
haskey(stackdict, x) && return stackdict[x]::tangent_type(typeof(x))

dx = Array{tangent_type(T), N}(undef, size(x)...)
stackdict[x] = dx
return _map_if_assigned!(x -> randn_tangent_internal(rng, x, stackdict), dx, x)
end
@generated function randn_tangent(rng::AbstractRNG, x::P) where {P}

# If `P` doesn't have a tangent space, always return `NoTangent()`.
tangent_type(P) === NoTangent && return NoTangent()
function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict) where {P}
tangent_type(P) == NoTangent && return NoTangent()

# This method can only handle struct types. Tell user to implement tangent type
# directly for primitive types.
isprimitivetype(P) && throw(error(
"$P is a primitive type. Implement a method of `randn_tangent` for it."
))
if tangent_type(P) <: MutableTangent
if !(stackdict isa IdDict)
throw(
ArgumentError(
"Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue."
)
)
end
if haskey(stackdict, x)
return stackdict[x]::tangent_type(P)
end
stackdict[x] = tangent_type(P)()
stackdict[x].fields = randn_tangent_struct_field(rng, x, stackdict)
return stackdict[x]::tangent_type(P)
else
return tangent_type(P)(randn_tangent_struct_field(rng, x, stackdict))
end
end

# Assume `P` is a generic struct type, and derive the tangent recursively.
@generated function randn_tangent_struct_field(rng::AbstractRNG, x::P, stackdict) where {P}
tangent_field_exprs = map(1:fieldcount(P)) do n
if tangent_field_type(P, n) <: PossiblyUninitTangent
V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))}
return :(isdefined(x, $n) ? $V(randn_tangent(rng, getfield(x, $n))) : $V())
return :(isdefined(x, $n) ? $V(randn_tangent_internal(rng, getfield(x, $n), stackdict)) : $V())
else
return :(randn_tangent(rng, getfield(x, $n)))
return :(randn_tangent_internal(rng, getfield(x, $n), stackdict))
end
end
tangent_fields_expr = Expr(:call, :tuple, tangent_field_exprs...)
return :($(tangent_type(P))($(backing_type(P))($tangent_fields_expr)))
return :($(backing_type(P))($tangent_fields_expr))
end

"""
Expand Down Expand Up @@ -793,7 +779,7 @@ for T in [Symbol, Int, Val]
@eval increment_field!!(::NoTangent, ::NoTangent, f::Union{$T}) = NoTangent()
end

#=
"""
tangent_test_cases()
Constructs a `Vector` of `Tuple`s containing test cases for the tangent infrastructure.
Expand All @@ -809,13 +795,9 @@ If the returned tuple has 5 elements, then the elements are interpreted as follo
2 - primal value
3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).
Generally speaking, it's very straightforward to produce test cases in the first format,
while the second requires more work. Consequently, at the time of writing there are many
more instances of the first format than the second.
Test cases in the first format make use of `zero_tangent` / `randn_tangent` etc to generate
tangents, but they're unable to check that `increment!!` is correct in an absolute sense.
=#
"""
function tangent_test_cases()

N_large = 33
Expand Down
Loading

2 comments on commit eb4f54c

@sunxd3
Copy link
Collaborator Author

@sunxd3 sunxd3 commented on eb4f54c Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/115344

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.51 -m "<description of version>" eb4f54c228370326856dbb3582409f6186dd99ee
git push origin v0.2.51

Please sign in to comment.