From 18b3b1c16977709055dc62579751a7e7e28bbee8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 24 Oct 2024 21:19:28 +0100 Subject: [PATCH 1/5] Disable fail-fast on CI --- .github/workflows/AD.yml | 2 +- .github/workflows/Interface.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 47ef8549..3777b346 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -9,8 +9,8 @@ on: jobs: test: runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} strategy: + fail-fast: false matrix: version: - '1.6' diff --git a/.github/workflows/Interface.yml b/.github/workflows/Interface.yml index ef1f4dc7..b305124e 100644 --- a/.github/workflows/Interface.yml +++ b/.github/workflows/Interface.yml @@ -10,8 +10,8 @@ on: jobs: test: runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} strategy: + fail-fast: false matrix: version: - '1.6' From 6952360633936194f5dbe162bd1930178346f461 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 28 Oct 2024 20:07:29 +0000 Subject: [PATCH 2/5] Inline expanded frule and rrule in BijectorsEnzymeExt --- Project.toml | 9 +- ext/BijectorsEnzymeExt.jl | 597 +++++++++++++++++++++++++++++++++++++- 2 files changed, 597 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index d13289bf..5f4ca470 100644 --- a/Project.toml +++ b/Project.toml @@ -26,21 +26,22 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BijectorsDistributionsADExt = "DistributionsAD" -BijectorsEnzymeExt = "Enzyme" +BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"] BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" -BijectorsTrackerExt = "Tracker" BijectorsTapirExt = "Tapir" +BijectorsTrackerExt = "Tracker" BijectorsZygoteExt = "Zygote" [compat] @@ -53,6 +54,7 @@ Distributions = "0.25.33" DistributionsAD = "0.6" DocStringExtensions = "0.9" Enzyme = "0.12.22" +EnzymeCore = "0.7.8" ForwardDiff = "0.10" Functors = "0.1, 0.2, 0.3, 0.4" InverseFunctions = "0.1" @@ -73,6 +75,7 @@ julia = "1.6" [extras] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index 1e8d8aa3..57f2e4b0 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -1,14 +1,599 @@ module BijectorsEnzymeExt if isdefined(Base, :get_extension) - using Enzyme: @import_frule, @import_rrule - using Bijectors: find_alpha + using Enzyme: Enzyme + using EnzymeCore: EnzymeCore + using Bijectors: Bijectors, ChainRulesCore else - using ..Enzyme: @import_frule, @import_rrule - using ..Bijectors: find_alpha + using ..Enzyme: Enzyme + using ..EnzymeCore: EnzymeCore + using ..Bijectors: Bijectors, ChainRulesCore end -@import_rrule typeof(find_alpha) Real Real Real -@import_frule typeof(find_alpha) Real Real Real +#= NOTE(penelopeysm): +Changes made to the way extensions were loaded in Julia 1.11.1 mean that it +is no longer sufficient to call Enzyme.@import_rrule and +Enzyme.@import_frule, as we did in previous versions. This is because both of +those macros rely on a method which is defined in EnzymeChainRulesCoreExt, +and on 1.11.1+, that extension is _not_ loaded before BijectorsEnzymeExt is +loaded. (In the past, for reasons which are not fully clear, +EnzymeChainRulesCoreExt _does_ get loaded first.) +See https://github.com/TuringLang/Bijectors.jl/pull/333 for further context. + +However, on versions of Julia where the 'default' extension resolution occurs, +we can still use the macros (see the else clause below). We do this to ensure +that the code is compatible with what may potentially be different versions of +Enzyme. + +The code in the if clause was derived by calling @macroexpand on @import_rrule +and @import_frule, then replacing `$(Expr(:meta, :inline))` with +`Base.@_inline_meta`. + +Note that this was done using Enzyme v0.12.36. This code will fail to track any +upstream changes to EnzymeChainRulesCoreExt, so there is no guarantee that this +code will work with later versions of Enzyme. +=# +@static if v"1.11.1" <= VERSION < v"1.12" + function (Enzyme.EnzymeRules).augmented_primal( + var"#238#config", + var"#239#fn"::var"#246#FA", + ::Enzyme.Type{var"#245#RetAnnotation"}, + var"#241#arg_1"::var"#247#AN_1", + var"#242#arg_2"::var"#248#AN_2", + var"#243#arg_3"::var"#249#AN_3"; + var"#244#kwargs"..., + ) where { + var"#245#RetAnnotation", + var"#246#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, + var"#247#AN_1"<:Enzyme.Annotation{<:Real}, + var"#248#AN_2"<:Enzyme.Annotation{<:Real}, + var"#249#AN_3"<:Enzyme.Annotation{<:Real}, + } + var"#231#primcopy_1" = + if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[1 + 1] + Enzyme.deepcopy((var"#241#arg_1").val) + else + (var"#241#arg_1").val + end + var"#232#primcopy_2" = + if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[2 + 1] + Enzyme.deepcopy((var"#242#arg_2").val) + else + (var"#242#arg_2").val + end + var"#233#primcopy_3" = + if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[3 + 1] + Enzyme.deepcopy((var"#243#arg_3").val) + else + (var"#243#arg_3").val + end + (var"#234#res", var"#235#pullback") = if var"#245#RetAnnotation" <: Enzyme.Const + ( + (var"#239#fn").val( + var"#231#primcopy_1", + var"#232#primcopy_2", + var"#233#primcopy_3"; + var"#244#kwargs"..., + ), + Enzyme.nothing, + ) + else + (ChainRulesCore).rrule( + (var"#239#fn").val, + var"#231#primcopy_1", + var"#232#primcopy_2", + var"#233#primcopy_3"; + var"#244#kwargs"..., + ) + end + var"#236#primal" = if (Enzyme.EnzymeRules).needs_primal(var"#238#config") + var"#234#res" + else + Enzyme.nothing + end + var"#237#shadow" = if !((Enzyme.EnzymeRules).needs_shadow(var"#238#config")) + Enzyme.nothing + else + if (Enzyme.EnzymeRules).width(var"#238#config") == 1 + (Enzyme.Enzyme).make_zero(var"#234#res") + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#238#config")) + ) do var"#250#j" + Base.@_inline_meta + (Enzyme.Enzyme).make_zero(var"#234#res") + end + end + end + return (Enzyme.EnzymeRules).AugmentedReturn( + var"#236#primal", var"#237#shadow", (var"#237#shadow", var"#235#pullback") + ) + end + + function (Enzyme.EnzymeRules).reverse( + var"#254#config", + var"#255#fn"::var"#264#FA", + ::Enzyme.Type{var"#262#RetAnnotation"}, + var"#257#tape"::var"#263#TapeTy", + var"#258#arg_1"::var"#265#AN_1", + var"#259#arg_2"::var"#266#AN_2", + var"#260#arg_3"::var"#267#AN_3"; + var"#261#kwargs"..., + ) where { + var"#262#RetAnnotation", + var"#263#TapeTy", + var"#264#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, + var"#265#AN_1"<:Enzyme.Annotation{<:Real}, + var"#266#AN_2"<:Enzyme.Annotation{<:Real}, + var"#267#AN_3"<:Enzyme.Annotation{<:Real}, + } + if !(var"#262#RetAnnotation" <: Enzyme.Const) + (var"#251#shadow", var"#252#pullback") = var"#257#tape" + var"#253#tcomb" = Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) + ) do var"#272#batch_i" + Base.@_inline_meta + var"#268#shad" = if (Enzyme.EnzymeRules).width(var"#254#config") == 1 + var"#251#shadow" + else + var"#251#shadow"[var"#272#batch_i"] + end + var"#269#res" = var"#252#pullback"(var"#268#shad") + for (var"#270#cr", var"#271#en") in Enzyme.zip( + var"#269#res", + (var"#255#fn", var"#258#arg_1", var"#259#arg_2", var"#260#arg_3"), + ) + if var"#271#en" isa Enzyme.Const || + var"#270#cr" isa (ChainRulesCore).NoTangent + continue + end + if var"#271#en" isa Enzyme.Active + continue + end + if (Enzyme.EnzymeRules).width(var"#254#config") == 1 + (var"#271#en").dval .+= var"#270#cr" + else + (var"#271#en").dval[var"#272#batch_i"] .+= var"#270#cr" + end + end + ( + if var"#255#fn" isa Enzyme.Active + var"#269#res"[1] + else + Enzyme.nothing + end, + if var"#258#arg_1" isa Enzyme.Active + if var"#269#res"[1 + 1] isa (ChainRulesCore).NoTangent + Enzyme.zero(var"#258#arg_1") + else + (ChainRulesCore).unthunk(var"#269#res"[1 + 1]) + end + else + Enzyme.nothing + end, + if var"#259#arg_2" isa Enzyme.Active + if var"#269#res"[2 + 1] isa (ChainRulesCore).NoTangent + Enzyme.zero(var"#259#arg_2") + else + (ChainRulesCore).unthunk(var"#269#res"[2 + 1]) + end + else + Enzyme.nothing + end, + if var"#260#arg_3" isa Enzyme.Active + if var"#269#res"[3 + 1] isa (ChainRulesCore).NoTangent + Enzyme.zero(var"#260#arg_3") + else + (ChainRulesCore).unthunk(var"#269#res"[3 + 1]) + end + else + Enzyme.nothing + end, + ) + end + return ( + begin + if var"#258#arg_1" isa Enzyme.Active + if (Enzyme.EnzymeRules).width(var"#254#config") == 1 + (var"#253#tcomb"[1])[1 + 1] + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) + ) do var"#273#batch_i" + Base.@_inline_meta + (var"#253#tcomb"[var"#273#batch_i"])[1 + 1] + end + end + else + Enzyme.nothing + end + end, + begin + if var"#259#arg_2" isa Enzyme.Active + if (Enzyme.EnzymeRules).width(var"#254#config") == 1 + (var"#253#tcomb"[1])[2 + 1] + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) + ) do var"#274#batch_i" + Base.@_inline_meta + (var"#253#tcomb"[var"#274#batch_i"])[2 + 1] + end + end + else + Enzyme.nothing + end + end, + begin + if var"#260#arg_3" isa Enzyme.Active + if (Enzyme.EnzymeRules).width(var"#254#config") == 1 + (var"#253#tcomb"[1])[3 + 1] + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) + ) do var"#275#batch_i" + Base.@_inline_meta + (var"#253#tcomb"[var"#275#batch_i"])[3 + 1] + end + end + else + Enzyme.nothing + end + end, + ) + end + return (Enzyme.nothing, Enzyme.nothing, Enzyme.nothing) + end + + function (Enzyme.EnzymeRules).reverse( + var"#280#config", + var"#281#fn"::var"#290#FA", + var"#282#dval"::Enzyme.Active{var"#288#RetAnnotation"}, + var"#283#tape"::var"#289#TapeTy", + var"#284#arg_1"::var"#291#AN_1", + var"#285#arg_2"::var"#292#AN_2", + var"#286#arg_3"::var"#293#AN_3"; + var"#287#kwargs"..., + ) where { + var"#288#RetAnnotation", + var"#289#TapeTy", + var"#290#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, + var"#291#AN_1"<:Enzyme.Annotation{<:Real}, + var"#292#AN_2"<:Enzyme.Annotation{<:Real}, + var"#293#AN_3"<:Enzyme.Annotation{<:Real}, + } + (var"#276#oldshadow", var"#277#pullback") = var"#283#tape" + var"#278#shadow" = (var"#282#dval").val + var"#279#tcomb" = Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) + ) do var"#298#batch_i" + Base.@_inline_meta + var"#294#shad" = if (Enzyme.EnzymeRules).width(var"#280#config") == 1 + var"#278#shadow" + else + var"#278#shadow"[var"#298#batch_i"] + end + var"#295#res" = var"#277#pullback"(var"#294#shad") + for (var"#296#cr", var"#297#en") in Enzyme.zip( + var"#295#res", + (var"#281#fn", var"#284#arg_1", var"#285#arg_2", var"#286#arg_3"), + ) + if var"#297#en" isa Enzyme.Const || var"#296#cr" isa (ChainRulesCore).NoTangent + continue + end + if var"#297#en" isa Enzyme.Active + continue + end + if (Enzyme.EnzymeRules).width(var"#280#config") == 1 + (var"#297#en").dval .+= var"#296#cr" + else + (var"#297#en").dval[var"#298#batch_i"] .+= var"#296#cr" + end + end + ( + if var"#281#fn" isa Enzyme.Active + var"#295#res"[1] + else + Enzyme.nothing + end, + if var"#284#arg_1" isa Enzyme.Active + if var"#295#res"[1 + 1] isa (ChainRulesCore).NoTangent + Enzyme.zero(var"#284#arg_1") + else + (ChainRulesCore).unthunk(var"#295#res"[1 + 1]) + end + else + Enzyme.nothing + end, + if var"#285#arg_2" isa Enzyme.Active + if var"#295#res"[2 + 1] isa (ChainRulesCore).NoTangent + Enzyme.zero(var"#285#arg_2") + else + (ChainRulesCore).unthunk(var"#295#res"[2 + 1]) + end + else + Enzyme.nothing + end, + if var"#286#arg_3" isa Enzyme.Active + if var"#295#res"[3 + 1] isa (ChainRulesCore).NoTangent + Enzyme.zero(var"#286#arg_3") + else + (ChainRulesCore).unthunk(var"#295#res"[3 + 1]) + end + else + Enzyme.nothing + end, + ) + end + return ( + begin + if var"#284#arg_1" isa Enzyme.Active + if (Enzyme.EnzymeRules).width(var"#280#config") == 1 + (var"#279#tcomb"[1])[1 + 1] + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) + ) do var"#299#batch_i" + Base.@_inline_meta + (var"#279#tcomb"[var"#299#batch_i"])[1 + 1] + end + end + else + Enzyme.nothing + end + end, + begin + if var"#285#arg_2" isa Enzyme.Active + if (Enzyme.EnzymeRules).width(var"#280#config") == 1 + (var"#279#tcomb"[1])[2 + 1] + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) + ) do var"#300#batch_i" + Base.@_inline_meta + (var"#279#tcomb"[var"#300#batch_i"])[2 + 1] + end + end + else + Enzyme.nothing + end + end, + begin + if var"#286#arg_3" isa Enzyme.Active + if (Enzyme.EnzymeRules).width(var"#280#config") == 1 + (var"#279#tcomb"[1])[3 + 1] + else + Enzyme.ntuple( + Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) + ) do var"#301#batch_i" + Base.@_inline_meta + (var"#279#tcomb"[var"#301#batch_i"])[3 + 1] + end + end + else + Enzyme.nothing + end + end, + ) + end + + function (Enzyme.EnzymeRules).forward( + var"#308#fn"::var"#315#FA", + ::Enzyme.Type{var"#314#RetAnnotation"}, + var"#310#arg_1"::var"#316#AN_1", + var"#311#arg_2"::var"#317#AN_2", + var"#312#arg_3"::var"#318#AN_3"; + var"#313#kwargs"..., + ) where { + var"#314#RetAnnotation", + var"#315#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, + var"#316#AN_1"<:Enzyme.Annotation{<:Real}, + var"#317#AN_2"<:Enzyme.Annotation{<:Real}, + var"#318#AN_3"<:Enzyme.Annotation{<:Real}, + } + var"#302#batchsize" = Enzyme.same_or_one( + 1, var"#310#arg_1", var"#311#arg_2", var"#312#arg_3" + ) + if var"#302#batchsize" == 1 + var"#306#dfn" = if var"#308#fn" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#308#fn").dval + end + var"#303#cres" = (ChainRulesCore).frule( + ( + var"#306#dfn", + if var"#310#arg_1" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#310#arg_1").dval + end, + if var"#311#arg_2" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#311#arg_2").dval + end, + if var"#312#arg_3" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#312#arg_3").dval + end, + ), + (var"#308#fn").val, + (var"#310#arg_1").val, + (var"#311#arg_2").val, + (var"#312#arg_3").val; + var"#313#kwargs"..., + ) + if var"#314#RetAnnotation" <: Enzyme.Const + return var"#303#cres"[2]::Enzyme.eltype(var"#314#RetAnnotation") + elseif var"#314#RetAnnotation" <: Enzyme.Duplicated + return Enzyme.Duplicated(var"#303#cres"[1], var"#303#cres"[2]) + elseif var"#314#RetAnnotation" <: Enzyme.DuplicatedNoNeed + return var"#303#cres"[2]::Enzyme.eltype(var"#314#RetAnnotation") + else + if false + nothing + else + Base.throw(Base.AssertionError("false")) + end + end + else + if var"#314#RetAnnotation" <: Enzyme.Const + var"#303#cres" = + Enzyme.ntuple(Enzyme.Val(var"#302#batchsize")) do var"#305#i" + Base.@_inline_meta + var"#306#dfn" = if var"#308#fn" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#308#fn").dval[var"#305#i"] + end + (ChainRulesCore).frule( + ( + var"#306#dfn", + if var"#310#arg_1" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#310#arg_1").dval[var"#305#i"] + end, + if var"#311#arg_2" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#311#arg_2").dval[var"#305#i"] + end, + if var"#312#arg_3" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#312#arg_3").dval[var"#305#i"] + end, + ), + (var"#308#fn").val, + (var"#310#arg_1").val, + (var"#311#arg_2").val, + (var"#312#arg_3").val; + var"#313#kwargs"..., + ) + end + return (var"#303#cres"[1])[2]::Enzyme.eltype(var"#314#RetAnnotation") + elseif var"#314#RetAnnotation" <: Enzyme.BatchDuplicated + var"#304#cres1" = begin + var"#305#i" = 1 + var"#306#dfn" = if var"#308#fn" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#308#fn").dval[var"#305#i"] + end + (ChainRulesCore).frule( + ( + var"#306#dfn", + if var"#310#arg_1" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#310#arg_1").dval[var"#305#i"] + end, + if var"#311#arg_2" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#311#arg_2").dval[var"#305#i"] + end, + if var"#312#arg_3" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#312#arg_3").dval[var"#305#i"] + end, + ), + (var"#308#fn").val, + (var"#310#arg_1").val, + (var"#311#arg_2").val, + (var"#312#arg_3").val; + var"#313#kwargs"..., + ) + end + var"#307#batches" = + Enzyme.ntuple(Enzyme.Val(var"#302#batchsize" - 1)) do var"#323#j" + Base.@_inline_meta + var"#305#i" = var"#323#j" + 1 + var"#306#dfn" = if var"#308#fn" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#308#fn").dval[var"#305#i"] + end + ((ChainRulesCore).frule( + ( + var"#306#dfn", + if var"#310#arg_1" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#310#arg_1").dval[var"#305#i"] + end, + if var"#311#arg_2" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#311#arg_2").dval[var"#305#i"] + end, + if var"#312#arg_3" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#312#arg_3").dval[var"#305#i"] + end, + ), + (var"#308#fn").val, + (var"#310#arg_1").val, + (var"#311#arg_2").val, + (var"#312#arg_3").val; + var"#313#kwargs"..., + ))[2] + end + return Enzyme.BatchDuplicated( + var"#304#cres1"[1], (var"#304#cres1"[2], var"#307#batches"...) + ) + elseif var"#314#RetAnnotation" <: Enzyme.BatchDuplicatedNoNeed + Enzyme.ntuple(Enzyme.Val(var"#302#batchsize")) do var"#305#i" + Base.@_inline_meta + var"#306#dfn" = if var"#308#fn" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#308#fn").dval[var"#305#i"] + end + ((ChainRulesCore).frule( + ( + var"#306#dfn", + if var"#310#arg_1" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#310#arg_1").dval[var"#305#i"] + end, + if var"#311#arg_2" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#311#arg_2").dval[var"#305#i"] + end, + if var"#312#arg_3" isa Enzyme.Const + (ChainRulesCore).NoTangent() + else + (var"#312#arg_3").dval[var"#305#i"] + end, + ), + (var"#308#fn").val, + (var"#310#arg_1").val, + (var"#311#arg_2").val, + (var"#312#arg_3").val; + var"#313#kwargs"..., + ))[2] + end + else + if false + nothing + else + Base.throw(Base.AssertionError("false")) + end + end + end + end +else + Enzyme.@import_rrule typeof(Bijectors.find_alpha) Real Real Real + Enzyme.@import_frule typeof(Bijectors.find_alpha) Real Real Real end + +end # module From d10ad87852f7dc65e40d67ec3e14c87e72d41464 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Oct 2024 03:11:18 +0000 Subject: [PATCH 3/5] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5f4ca470..62400e20 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.18" +version = "0.13.19" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From aed472a12c6f3ff57f514703da5f4128801e3e03 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Oct 2024 16:16:52 +0000 Subject: [PATCH 4/5] Remove BijectorsEnzymeExt on 1.11.1+ --- ext/BijectorsEnzymeExt.jl | 595 +------------------------------------- 1 file changed, 7 insertions(+), 588 deletions(-) diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index 57f2e4b0..303fd92f 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -1,599 +1,18 @@ module BijectorsEnzymeExt if isdefined(Base, :get_extension) - using Enzyme: Enzyme - using EnzymeCore: EnzymeCore - using Bijectors: Bijectors, ChainRulesCore + using Enzyme: @import_rrule, @import_frule + using Bijectors: find_alpha else - using ..Enzyme: Enzyme - using ..EnzymeCore: EnzymeCore - using ..Bijectors: Bijectors, ChainRulesCore + using ..Enzyme: @import_rrule, @import_frule + using ..Bijectors: find_alpha end -#= NOTE(penelopeysm): -Changes made to the way extensions were loaded in Julia 1.11.1 mean that it -is no longer sufficient to call Enzyme.@import_rrule and -Enzyme.@import_frule, as we did in previous versions. This is because both of -those macros rely on a method which is defined in EnzymeChainRulesCoreExt, -and on 1.11.1+, that extension is _not_ loaded before BijectorsEnzymeExt is -loaded. (In the past, for reasons which are not fully clear, -EnzymeChainRulesCoreExt _does_ get loaded first.) - -See https://github.com/TuringLang/Bijectors.jl/pull/333 for further context. - -However, on versions of Julia where the 'default' extension resolution occurs, -we can still use the macros (see the else clause below). We do this to ensure -that the code is compatible with what may potentially be different versions of -Enzyme. - -The code in the if clause was derived by calling @macroexpand on @import_rrule -and @import_frule, then replacing `$(Expr(:meta, :inline))` with -`Base.@_inline_meta`. - -Note that this was done using Enzyme v0.12.36. This code will fail to track any -upstream changes to EnzymeChainRulesCoreExt, so there is no guarantee that this -code will work with later versions of Enzyme. -=# @static if v"1.11.1" <= VERSION < v"1.12" - function (Enzyme.EnzymeRules).augmented_primal( - var"#238#config", - var"#239#fn"::var"#246#FA", - ::Enzyme.Type{var"#245#RetAnnotation"}, - var"#241#arg_1"::var"#247#AN_1", - var"#242#arg_2"::var"#248#AN_2", - var"#243#arg_3"::var"#249#AN_3"; - var"#244#kwargs"..., - ) where { - var"#245#RetAnnotation", - var"#246#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#247#AN_1"<:Enzyme.Annotation{<:Real}, - var"#248#AN_2"<:Enzyme.Annotation{<:Real}, - var"#249#AN_3"<:Enzyme.Annotation{<:Real}, - } - var"#231#primcopy_1" = - if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[1 + 1] - Enzyme.deepcopy((var"#241#arg_1").val) - else - (var"#241#arg_1").val - end - var"#232#primcopy_2" = - if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[2 + 1] - Enzyme.deepcopy((var"#242#arg_2").val) - else - (var"#242#arg_2").val - end - var"#233#primcopy_3" = - if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[3 + 1] - Enzyme.deepcopy((var"#243#arg_3").val) - else - (var"#243#arg_3").val - end - (var"#234#res", var"#235#pullback") = if var"#245#RetAnnotation" <: Enzyme.Const - ( - (var"#239#fn").val( - var"#231#primcopy_1", - var"#232#primcopy_2", - var"#233#primcopy_3"; - var"#244#kwargs"..., - ), - Enzyme.nothing, - ) - else - (ChainRulesCore).rrule( - (var"#239#fn").val, - var"#231#primcopy_1", - var"#232#primcopy_2", - var"#233#primcopy_3"; - var"#244#kwargs"..., - ) - end - var"#236#primal" = if (Enzyme.EnzymeRules).needs_primal(var"#238#config") - var"#234#res" - else - Enzyme.nothing - end - var"#237#shadow" = if !((Enzyme.EnzymeRules).needs_shadow(var"#238#config")) - Enzyme.nothing - else - if (Enzyme.EnzymeRules).width(var"#238#config") == 1 - (Enzyme.Enzyme).make_zero(var"#234#res") - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#238#config")) - ) do var"#250#j" - Base.@_inline_meta - (Enzyme.Enzyme).make_zero(var"#234#res") - end - end - end - return (Enzyme.EnzymeRules).AugmentedReturn( - var"#236#primal", var"#237#shadow", (var"#237#shadow", var"#235#pullback") - ) - end - - function (Enzyme.EnzymeRules).reverse( - var"#254#config", - var"#255#fn"::var"#264#FA", - ::Enzyme.Type{var"#262#RetAnnotation"}, - var"#257#tape"::var"#263#TapeTy", - var"#258#arg_1"::var"#265#AN_1", - var"#259#arg_2"::var"#266#AN_2", - var"#260#arg_3"::var"#267#AN_3"; - var"#261#kwargs"..., - ) where { - var"#262#RetAnnotation", - var"#263#TapeTy", - var"#264#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#265#AN_1"<:Enzyme.Annotation{<:Real}, - var"#266#AN_2"<:Enzyme.Annotation{<:Real}, - var"#267#AN_3"<:Enzyme.Annotation{<:Real}, - } - if !(var"#262#RetAnnotation" <: Enzyme.Const) - (var"#251#shadow", var"#252#pullback") = var"#257#tape" - var"#253#tcomb" = Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#272#batch_i" - Base.@_inline_meta - var"#268#shad" = if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - var"#251#shadow" - else - var"#251#shadow"[var"#272#batch_i"] - end - var"#269#res" = var"#252#pullback"(var"#268#shad") - for (var"#270#cr", var"#271#en") in Enzyme.zip( - var"#269#res", - (var"#255#fn", var"#258#arg_1", var"#259#arg_2", var"#260#arg_3"), - ) - if var"#271#en" isa Enzyme.Const || - var"#270#cr" isa (ChainRulesCore).NoTangent - continue - end - if var"#271#en" isa Enzyme.Active - continue - end - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#271#en").dval .+= var"#270#cr" - else - (var"#271#en").dval[var"#272#batch_i"] .+= var"#270#cr" - end - end - ( - if var"#255#fn" isa Enzyme.Active - var"#269#res"[1] - else - Enzyme.nothing - end, - if var"#258#arg_1" isa Enzyme.Active - if var"#269#res"[1 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#258#arg_1") - else - (ChainRulesCore).unthunk(var"#269#res"[1 + 1]) - end - else - Enzyme.nothing - end, - if var"#259#arg_2" isa Enzyme.Active - if var"#269#res"[2 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#259#arg_2") - else - (ChainRulesCore).unthunk(var"#269#res"[2 + 1]) - end - else - Enzyme.nothing - end, - if var"#260#arg_3" isa Enzyme.Active - if var"#269#res"[3 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#260#arg_3") - else - (ChainRulesCore).unthunk(var"#269#res"[3 + 1]) - end - else - Enzyme.nothing - end, - ) - end - return ( - begin - if var"#258#arg_1" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#253#tcomb"[1])[1 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#273#batch_i" - Base.@_inline_meta - (var"#253#tcomb"[var"#273#batch_i"])[1 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#259#arg_2" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#253#tcomb"[1])[2 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#274#batch_i" - Base.@_inline_meta - (var"#253#tcomb"[var"#274#batch_i"])[2 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#260#arg_3" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#253#tcomb"[1])[3 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#275#batch_i" - Base.@_inline_meta - (var"#253#tcomb"[var"#275#batch_i"])[3 + 1] - end - end - else - Enzyme.nothing - end - end, - ) - end - return (Enzyme.nothing, Enzyme.nothing, Enzyme.nothing) - end - - function (Enzyme.EnzymeRules).reverse( - var"#280#config", - var"#281#fn"::var"#290#FA", - var"#282#dval"::Enzyme.Active{var"#288#RetAnnotation"}, - var"#283#tape"::var"#289#TapeTy", - var"#284#arg_1"::var"#291#AN_1", - var"#285#arg_2"::var"#292#AN_2", - var"#286#arg_3"::var"#293#AN_3"; - var"#287#kwargs"..., - ) where { - var"#288#RetAnnotation", - var"#289#TapeTy", - var"#290#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#291#AN_1"<:Enzyme.Annotation{<:Real}, - var"#292#AN_2"<:Enzyme.Annotation{<:Real}, - var"#293#AN_3"<:Enzyme.Annotation{<:Real}, - } - (var"#276#oldshadow", var"#277#pullback") = var"#283#tape" - var"#278#shadow" = (var"#282#dval").val - var"#279#tcomb" = Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#298#batch_i" - Base.@_inline_meta - var"#294#shad" = if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - var"#278#shadow" - else - var"#278#shadow"[var"#298#batch_i"] - end - var"#295#res" = var"#277#pullback"(var"#294#shad") - for (var"#296#cr", var"#297#en") in Enzyme.zip( - var"#295#res", - (var"#281#fn", var"#284#arg_1", var"#285#arg_2", var"#286#arg_3"), - ) - if var"#297#en" isa Enzyme.Const || var"#296#cr" isa (ChainRulesCore).NoTangent - continue - end - if var"#297#en" isa Enzyme.Active - continue - end - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#297#en").dval .+= var"#296#cr" - else - (var"#297#en").dval[var"#298#batch_i"] .+= var"#296#cr" - end - end - ( - if var"#281#fn" isa Enzyme.Active - var"#295#res"[1] - else - Enzyme.nothing - end, - if var"#284#arg_1" isa Enzyme.Active - if var"#295#res"[1 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#284#arg_1") - else - (ChainRulesCore).unthunk(var"#295#res"[1 + 1]) - end - else - Enzyme.nothing - end, - if var"#285#arg_2" isa Enzyme.Active - if var"#295#res"[2 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#285#arg_2") - else - (ChainRulesCore).unthunk(var"#295#res"[2 + 1]) - end - else - Enzyme.nothing - end, - if var"#286#arg_3" isa Enzyme.Active - if var"#295#res"[3 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#286#arg_3") - else - (ChainRulesCore).unthunk(var"#295#res"[3 + 1]) - end - else - Enzyme.nothing - end, - ) - end - return ( - begin - if var"#284#arg_1" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#279#tcomb"[1])[1 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#299#batch_i" - Base.@_inline_meta - (var"#279#tcomb"[var"#299#batch_i"])[1 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#285#arg_2" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#279#tcomb"[1])[2 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#300#batch_i" - Base.@_inline_meta - (var"#279#tcomb"[var"#300#batch_i"])[2 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#286#arg_3" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#279#tcomb"[1])[3 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#301#batch_i" - Base.@_inline_meta - (var"#279#tcomb"[var"#301#batch_i"])[3 + 1] - end - end - else - Enzyme.nothing - end - end, - ) - end - - function (Enzyme.EnzymeRules).forward( - var"#308#fn"::var"#315#FA", - ::Enzyme.Type{var"#314#RetAnnotation"}, - var"#310#arg_1"::var"#316#AN_1", - var"#311#arg_2"::var"#317#AN_2", - var"#312#arg_3"::var"#318#AN_3"; - var"#313#kwargs"..., - ) where { - var"#314#RetAnnotation", - var"#315#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#316#AN_1"<:Enzyme.Annotation{<:Real}, - var"#317#AN_2"<:Enzyme.Annotation{<:Real}, - var"#318#AN_3"<:Enzyme.Annotation{<:Real}, - } - var"#302#batchsize" = Enzyme.same_or_one( - 1, var"#310#arg_1", var"#311#arg_2", var"#312#arg_3" - ) - if var"#302#batchsize" == 1 - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval - end - var"#303#cres" = (ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ) - if var"#314#RetAnnotation" <: Enzyme.Const - return var"#303#cres"[2]::Enzyme.eltype(var"#314#RetAnnotation") - elseif var"#314#RetAnnotation" <: Enzyme.Duplicated - return Enzyme.Duplicated(var"#303#cres"[1], var"#303#cres"[2]) - elseif var"#314#RetAnnotation" <: Enzyme.DuplicatedNoNeed - return var"#303#cres"[2]::Enzyme.eltype(var"#314#RetAnnotation") - else - if false - nothing - else - Base.throw(Base.AssertionError("false")) - end - end - else - if var"#314#RetAnnotation" <: Enzyme.Const - var"#303#cres" = - Enzyme.ntuple(Enzyme.Val(var"#302#batchsize")) do var"#305#i" - Base.@_inline_meta - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - (ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ) - end - return (var"#303#cres"[1])[2]::Enzyme.eltype(var"#314#RetAnnotation") - elseif var"#314#RetAnnotation" <: Enzyme.BatchDuplicated - var"#304#cres1" = begin - var"#305#i" = 1 - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - (ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ) - end - var"#307#batches" = - Enzyme.ntuple(Enzyme.Val(var"#302#batchsize" - 1)) do var"#323#j" - Base.@_inline_meta - var"#305#i" = var"#323#j" + 1 - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - ((ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ))[2] - end - return Enzyme.BatchDuplicated( - var"#304#cres1"[1], (var"#304#cres1"[2], var"#307#batches"...) - ) - elseif var"#314#RetAnnotation" <: Enzyme.BatchDuplicatedNoNeed - Enzyme.ntuple(Enzyme.Val(var"#302#batchsize")) do var"#305#i" - Base.@_inline_meta - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - ((ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ))[2] - end - else - if false - nothing - else - Base.throw(Base.AssertionError("false")) - end - end - end - end + @warn "Bijectors and Enzyme do not work together on Julia $VERSION" else - Enzyme.@import_rrule typeof(Bijectors.find_alpha) Real Real Real - Enzyme.@import_frule typeof(Bijectors.find_alpha) Real Real Real + @import_rrule typeof(find_alpha) Real Real Real + @import_frule typeof(find_alpha) Real Real Real end end # module From 2ee8a5d4637fc1feff9ca2fc45cd402d052d872d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Oct 2024 16:52:01 +0000 Subject: [PATCH 5/5] Tapir -> Mooncake (#338) * Tapir -> Mooncake * Bump minor version * Mark Mooncake test as broken * Remove BijectorsEnzymeExt on 1.11.1+ * Increase tolerance on `ordered` test --- .github/workflows/AD.yml | 4 +- Project.toml | 10 ++-- ...orsTapirExt.jl => BijectorsMooncakeExt.jl} | 15 +++--- test/ad/chainrules.jl | 14 ++--- test/ad/utils.jl | 54 +++++++++++-------- test/bijectors/ordered.jl | 4 +- test/runtests.jl | 6 +-- 7 files changed, 60 insertions(+), 47 deletions(-) rename ext/{BijectorsTapirExt.jl => BijectorsMooncakeExt.jl} (77%) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 3777b346..7d0aa4ae 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -23,13 +23,13 @@ jobs: AD: - Enzyme - ForwardDiff - - Tapir + - Mooncake - Tracker - ReverseDiff - Zygote exclude: - version: 1.6 - AD: Tapir + AD: Mooncake # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - version: 1.6 diff --git a/Project.toml b/Project.toml index 62400e20..283b706c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.19" +version = "0.14.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -30,7 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -40,7 +40,7 @@ BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"] BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" -BijectorsTapirExt = "Tapir" +BijectorsMooncakeExt = "Mooncake" BijectorsTrackerExt = "Tracker" BijectorsZygoteExt = "Zygote" @@ -67,7 +67,7 @@ Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.4, 2" Statistics = "1" -Tapir = "0.2.23" +Mooncake = "0.4.19" Tracker = "0.2" Zygote = "0.6.63" julia = "1.6" @@ -79,6 +79,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/BijectorsTapirExt.jl b/ext/BijectorsMooncakeExt.jl similarity index 77% rename from ext/BijectorsTapirExt.jl rename to ext/BijectorsMooncakeExt.jl index 70805a82..d7285bf6 100644 --- a/ext/BijectorsTapirExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,10 +1,11 @@ -module BijectorsTapirExt +module BijectorsMooncakeExt if isdefined(Base, :get_extension) - using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule + using Mooncake: + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore else - using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule + using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule using ..Bijectors: find_alpha, ChainRulesCore end @@ -19,20 +20,20 @@ end # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) -function Tapir.rrule!!( +function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} # Require that the integer is non-differentiable. - if tangent_type(I) != Tapir.NoTangent + if tangent_type(I) != Mooncake.NoTangent msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." throw(ArgumentError(msg)) end out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) function find_alpha_pb(dout::P) _, dx, dy, _ = pb(dout) - return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() + return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData() end - return Tapir.zero_fcodual(out), find_alpha_pb + return Mooncake.zero_fcodual(out), find_alpha_pb end end diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index bcdb9523..a2c13df1 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -27,9 +27,9 @@ end test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) - if @isdefined Tapir + if @isdefined Mooncake rng = Xoshiro(123456) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -37,9 +37,9 @@ end z; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -47,9 +47,9 @@ end 3; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -57,7 +57,7 @@ end UInt32(3); is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 3e21e693..2e709491 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) b in ( :ForwardDiff, :Zygote, - :Tapir, + :Mooncake, :ReverseDiff, :Enzyme, :EnzymeForward, @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" - rule = Tapir.build_rrule(f, x; safety_on=false) - if :tapir in broken - @test_broken( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) - else - @test( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) + if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10" + try + Mooncake.build_rrule(f, x) + catch exc + # TODO(penelopeysm): + # @test_throws AssertionError (expr...) doesn't work, unclear why + @test exc isa AssertionError end + # TODO: The above @test_throws happens because of + # https://github.com/compintell/Mooncake.jl/issues/319. If that test + # fails, it probably means that the issue was fixed, in which case + # we can remove that block and uncomment the following instead. + + # rule = Mooncake.build_rrule(f, x) + # if :Mooncake in broken + # @test_broken ( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # else + # @test( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # end end return nothing diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index b2115fe2..60354005 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -127,12 +127,12 @@ end end end # Check that the quantiles are reasonable, i.e. within - # 5 standard errors of the true quantiles (and that the MCSE is + # 6 standard errors of the true quantiles (and that the MCSE is # not too large). for i in 1:k for j in 1:length(qts) @test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2 - @test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j] + @test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j] end end end diff --git a/test/runtests.jl b/test/runtests.jl index 914c0e32..638bd15c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,12 +34,12 @@ if VERSION < v"1.9" using Compat: stack end -# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing +# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing # on at least version 1.10. if VERSION >= v"1.10" using Pkg - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + using Mooncake end const GROUP = get(ENV, "GROUP", "All")