Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify Tapir.jl Implementation #76

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = [
"Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors",
]
version = "1.6.1"
version = "1.6.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
20 changes: 9 additions & 11 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,33 +325,31 @@ mode(::AutoSymbolics) = SymbolicMode()
"""
AutoTapir

Struct used to select the [Tapir.jl](https://github.com/withbayes/Tapir.jl) backend for automatic differentiation.
Struct used to select the [Tapir.jl](https://github.com/compintell/Tapir.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoTapir(; safe_mode=true)
AutoTapir(; debug_mode::Bool)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willtebbutt Does this allow constructors like AutoTapir(false)? It feels a bit inconvenient if we force everyone to type AutoTapir(debug_mode=false) all the time, particularly for REPL.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All other ADTypes constructors expect keyword arguments, the positional arguments are not documented and thus not part of the API. AutoTapir should not be different in this regard


# Fields

- `safe_mode::Bool`: whether to run additional checks to catch errors early. While this is
on by default to ensure that users are aware of this option, you should generally turn
it off for actual use, as it has substantial performance implications.
If you encounter a problem with using Tapir (it fails to differentiate a function, or
something truly nasty like a segfault occurs), then you should try switching `safe_mode`
on and look at what happens. Often errors are caught earlier and the error messages are
more useful.
- `debug_mode::Bool`: whether to run additional checks to catch errors early. This should
be set to `false` in general use of the package. If you encounter a problem when using
Tapir.jl (it fails to differentiate a function, or something truly nasty like a segfault
occurs), then you should switch `debug_mode` on. This often results in errors being
caught earlier in execution, and the associated error messages being more useful.
"""
Base.@kwdef struct AutoTapir <: AbstractADType
safe_mode::Bool = true
debug_mode::Bool
end

mode(::AutoTapir) = ReverseMode()

function Base.show(io::IO, backend::AutoTapir)
print(io, AutoTapir, "(")
!(backend.safe_mode) && print(io, "safe_mode=false")
print(io, "debug_mode=$(backend.debug_mode)")
print(io, ")")
end

Expand Down
8 changes: 4 additions & 4 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ end
end

@testset "AutoTapir" begin
ad = AutoTapir()
ad = AutoTapir(debug_mode=true)
@test ad isa AbstractADType
@test ad isa AutoTapir
@test mode(ad) isa ReverseMode
@test ad.safe_mode
@test ad.debug_mode

ad = AutoTapir(; safe_mode = false)
@test !ad.safe_mode
ad = AutoTapir(; debug_mode = false)
@test !ad.debug_mode
end

@testset "AutoTracker" begin
Expand Down
4 changes: 2 additions & 2 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ for backend in [
ADTypes.AutoReverseDiff(),
ADTypes.AutoReverseDiff(compile = true),
ADTypes.AutoSymbolics(),
ADTypes.AutoTapir(),
ADTypes.AutoTapir(safe_mode = false),
ADTypes.AutoTapir(debug_mode = true),
ADTypes.AutoTapir(debug_mode = false),
ADTypes.AutoTracker(),
ADTypes.AutoZygote(),
# sparse
Expand Down
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ function every_ad()
AutoPolyesterForwardDiff(),
AutoReverseDiff(),
AutoSymbolics(),
AutoTapir(),
AutoTapir(debug_mode = true),
AutoTapir(debug_mode = false),
AutoTracker(),
AutoZygote()
]
Expand All @@ -69,8 +70,8 @@ function every_ad_with_options()
AutoReverseDiff(),
AutoReverseDiff(compile = true),
AutoSymbolics(),
AutoTapir(),
AutoTapir(safe_mode = false),
AutoTapir(debug_mode = true),
AutoTapir(debug_mode = false),
AutoTracker(),
AutoZygote()
]
Expand Down
2 changes: 1 addition & 1 deletion test/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Test
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff
@test ADTypes.Auto(:Symbolics) isa AutoSymbolics
@test ADTypes.Auto(:Tapir) isa AutoTapir
@test ADTypes.Auto(:Tapir; debug_mode=false) isa AutoTapir
@test ADTypes.Auto(:Tracker) isa AutoTracker
@test ADTypes.Auto(:Zygote) isa AutoZygote

Expand Down
Loading