@@ -907,28 +907,29 @@ function link!!(
907907 x = vi[:]
908908 y, logjac = with_logabsdet_jacobian (b, x)
909909
910- # Set parameters
911- vi_new = unflatten (vi, y)
912- # Update logjac. We can overwrite any old value since there is only
913- # a single logjac term to worry about.
914- vi_new = setlogjac!! (vi_new, logjac)
915- return settrans!! (vi_new , t)
910+ # Set parameters and add the logjac term.
911+ vi = unflatten (vi, y)
912+ if hasacc (vi, Val ( :LogJacobian ))
913+ vi = acclogjac!! (vi, logjac)
914+ end
915+ return settrans!! (vi , t)
916916end
917917
918918function invlink!! (
919919 t:: StaticTransformation{<:Bijectors.Transform} , vi:: AbstractVarInfo , :: Model
920920)
921921 b = t. bijector
922922 y = vi[:]
923- x = b ( y)
923+ x, inv_logjac = with_logabsdet_jacobian (b, y)
924924
925- # Set parameters
926- vi_new = unflatten (vi, x)
927- # Reset logjac to 0.
928- if hasacc (vi_new, Val (:LogJacobian ))
929- vi_new = map_accumulator!! (zero, vi_new, Val (:LogJacobian ))
925+ # Mildly confusing: we need to _add_ the logjac of the inverse transform,
926+ # because we are trying to remove the logjac of the forward transform
927+ # that was previously accumulated when linking.
928+ vi = unflatten (vi, x)
929+ if hasacc (vi, Val (:LogJacobian ))
930+ vi = acclogjac!! (vi, inv_logjac)
930931 end
931- return settrans!! (vi_new , NoTransformation ())
932+ return settrans!! (vi , NoTransformation ())
932933end
933934
934935"""
0 commit comments