Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BijectorsEnzymeExt on 1.11.1 #337

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1.6'
Expand All @@ -23,13 +23,13 @@ jobs:
AD:
- Enzyme
- ForwardDiff
- Tapir
- Mooncake
- Tracker
- ReverseDiff
- Zygote
exclude:
- version: 1.6
AD: Tapir
AD: Mooncake
# TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see
# discussion in https://github.com/TuringLang/Bijectors.jl/pull.
- version: 1.6
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Interface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1.6'
Expand Down
15 changes: 9 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.18"
version = "0.14.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand All @@ -26,21 +26,22 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsEnzymeExt = "Enzyme"
BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"]
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsMooncakeExt = "Mooncake"
BijectorsTrackerExt = "Tracker"
BijectorsTapirExt = "Tapir"
BijectorsZygoteExt = "Zygote"

[compat]
Expand All @@ -53,6 +54,7 @@ Distributions = "0.25.33"
DistributionsAD = "0.6"
DocStringExtensions = "0.9"
Enzyme = "0.12.22"
EnzymeCore = "0.7.8"
ForwardDiff = "0.10"
Functors = "0.1, 0.2, 0.3, 0.4"
InverseFunctions = "0.1"
Expand All @@ -65,17 +67,18 @@ Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Statistics = "1"
Tapir = "0.2.23"
Mooncake = "0.4.19"
Tracker = "0.2"
Zygote = "0.6.63"
julia = "1.6"

[extras]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
14 changes: 9 additions & 5 deletions ext/BijectorsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
module BijectorsEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme: @import_frule, @import_rrule
using Enzyme: @import_rrule, @import_frule
using Bijectors: find_alpha
else
using ..Enzyme: @import_frule, @import_rrule
using ..Enzyme: @import_rrule, @import_frule
using ..Bijectors: find_alpha
end

@import_rrule typeof(find_alpha) Real Real Real
@import_frule typeof(find_alpha) Real Real Real

@static if v"1.11.1" <= VERSION < v"1.12"
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
@warn "Bijectors and Enzyme do not work together on Julia $VERSION"
else
@import_rrule typeof(find_alpha) Real Real Real
@import_frule typeof(find_alpha) Real Real Real
end

end # module
15 changes: 8 additions & 7 deletions ext/BijectorsTapirExt.jl → ext/BijectorsMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module BijectorsTapirExt
module BijectorsMooncakeExt

if isdefined(Base, :get_extension)
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule
using Mooncake:
@is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule
using Bijectors: find_alpha, ChainRulesCore
else
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule
using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule
using ..Bijectors: find_alpha, ChainRulesCore
end

Expand All @@ -19,20 +20,20 @@ end
# unusual Integer type is encountered.
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})

function Tapir.rrule!!(
function Mooncake.rrule!!(
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
) where {P<:Base.IEEEFloat,I<:Integer}
# Require that the integer is non-differentiable.
if tangent_type(I) != Tapir.NoTangent
if tangent_type(I) != Mooncake.NoTangent
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
throw(ArgumentError(msg))
end
out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z))
function find_alpha_pb(dout::P)
_, dx, dy, _ = pb(dout)
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData()
return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData()
end
return Tapir.zero_fcodual(out), find_alpha_pb
return Mooncake.zero_fcodual(out), find_alpha_pb
end

end
14 changes: 7 additions & 7 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,37 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Tapir
if @isdefined Mooncake
rng = Xoshiro(123456)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
z;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
3;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
UInt32(3);
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
end

Expand Down
54 changes: 33 additions & 21 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
b in (
:ForwardDiff,
:Zygote,
:Tapir,
:Mooncake,
:ReverseDiff,
:Enzyme,
:EnzymeForward,
Expand Down Expand Up @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
@test_broken(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
else
@test(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10"
try
Mooncake.build_rrule(f, x)
catch exc
# TODO(penelopeysm):
# @test_throws AssertionError (expr...) doesn't work, unclear why
@test exc isa AssertionError
end
# TODO: The above @test_throws happens because of
# https://github.com/compintell/Mooncake.jl/issues/319. If that test
# fails, it probably means that the issue was fixed, in which case
# we can remove that block and uncomment the following instead.

# rule = Mooncake.build_rrule(f, x)
# if :Mooncake in broken
# @test_broken (
# isapprox(
# Mooncake.value_and_gradient!!(rule, f, x)[2][2],
# finitediff;
# rtol=rtol,
# atol=atol,
# )
# )
# else
# @test(
# isapprox(
# Mooncake.value_and_gradient!!(rule, f, x)[2][2],
# finitediff;
# rtol=rtol,
# atol=atol,
# )
# )
# end
end

return nothing
Expand Down
4 changes: 2 additions & 2 deletions test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ end
end
end
# Check that the quantiles are reasonable, i.e. within
# 5 standard errors of the true quantiles (and that the MCSE is
# 6 standard errors of the true quantiles (and that the MCSE is
# not too large).
for i in 1:k
for j in 1:length(qts)
@test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2
@test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j]
@test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j]
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ if VERSION < v"1.9"
using Compat: stack
end

# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing
# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing
# on at least version 1.10.
if VERSION >= v"1.10"
using Pkg
Pkg.add("Tapir")
using Tapir
Pkg.add("Mooncake")
using Mooncake
end

const GROUP = get(ENV, "GROUP", "All")
Expand Down
Loading