Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Tapir support #71

Closed
wants to merge 20 commits into from
Closed

add Tapir support #71

wants to merge 20 commits into from

Conversation

Red-Portal
Copy link
Member

No description provided.

This was referenced Jun 25, 2024
@yebai
Copy link
Member

yebai commented Aug 3, 2024

@sunxd3 can you help complete this integration?

@sunxd3
Copy link
Member

sunxd3 commented Aug 3, 2024

After I finish the work on hands, will work on this!

test/interface/ad.jl Outdated Show resolved Hide resolved
.github/workflows/CI.yml Outdated Show resolved Hide resolved
@yebai yebai force-pushed the master branch 2 times, most recently from 40fd15b to 48fc01d Compare August 9, 2024 15:05
@yebai
Copy link
Member

yebai commented Aug 9, 2024

I formatted the code on the master branch and this PR so the code changes are more readable. Unfortunately, this means some manual merge efforts are required for other PRs, like #67.

@yebai
Copy link
Member

yebai commented Aug 9, 2024

@willtebbutt It seems the only remaining issue is the lack of a rule for Intrinsics.fptrunc. Can this be quickly added?

https://github.com/TuringLang/AdvancedVI.jl/actions/runs/10321643466/job/28575032583?pr=71#step:6:1408

@willtebbutt
Copy link
Member

Just had a look at the source -- it looks like this is happening in the rule that Tapir.jl is deriving for randn!. Shoudn't be a problem to add a rule because this intrinsic seems fairly safe (it's converting a Float64 into a Float32).

@Red-Portal
Copy link
Member Author

Red-Portal commented Aug 9, 2024

I formatted the code on the master branch and this PR so the code changes are more readable. Unfortunately, this means some manual merge efforts are required for other PRs, like #67.

@yebai Thanks! I'll deal with 67 myself.

@Red-Portal
Copy link
Member Author

@willtebbutt BTW, in this PR, I had to modify the package so that we carry around the rrule. I personally don't want to do this, is there a Tapir native way? DifferentiationInterface unfortunately doesn't support auxiliary input yet, so I can't rely on it for this (unless we wait until it gets implemented.)

@willtebbutt
Copy link
Member

@willtebbutt BTW, in this PR, I had to modify the package so that we carry around the rrule. I personally don't want to do this, is there a Tapir native way? DifferentiationInterface unfortunately doesn't support auxiliary input yet, so I can't rely on it for this (unless we wait until it gets implemented.)

Alas, there is not presently -- you'll have to continue to pass it around.

@willtebbutt
Copy link
Member

@willtebbutt It seems the only remaining issue is the lack of a rule for Intrinsics.fptrunc. Can this be quickly added?

This is now resolved in v0.2.33 .

@willtebbutt
Copy link
Member

Some of the remaining problems are to do with a different intrinsic, Core.Intrinsics.fpext. A version of Tapir.jl which addresses this should be available in an hour or so.

@willtebbutt
Copy link
Member

I'm pretty sure that my comment will unblock this PR, but not everything will pass until v0.2.38 of Tapir.jl is available -- PR here

rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x)
DiffResults.value!(out, y)
DiffResults.gradient!(out, last(g))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DiffResults.gradient!(out, last(g))
DiffResults.gradient!(out, g[2])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willtebbutt, to clarify, we don't need this change. Is that correct?

test/interface/ad.jl Outdated Show resolved Hide resolved
@yebai
Copy link
Member

yebai commented Aug 21, 2024

All tests are passing now!

@@ -19,7 +19,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7'
#- '1.7'
- '1.10'
Copy link
Member

@yebai yebai Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willtebbutt, can you adapt the Bijectors setup so we don't need to comment out 1.7?

@yebai yebai mentioned this pull request Aug 22, 2024
@yebai
Copy link
Member

yebai commented Aug 22, 2024

#67 introduces a few breaking changes to the value_and_gradient! interface, which is not yet incorporated in this PR.

@willtebbutt
Copy link
Member

@Red-Portal is your plan to continue working on this PR, or should I take a look at finishing it off now that everything seems to be working on Tapir.jl's end?

@Red-Portal
Copy link
Member Author

Hi all, sorry for the confusion caused by #67 . I was a little surprised that it broke some things because it wasn't changing the interface at all. Well turns out it is this PR that is breaking the interface (purely to carry around the rrules), but the master branch is still using the old interface, which is what is causing the conflicts.

Now that #67 is out of the way, I will be able to to spend more time on this PR. I'll take a look from now on. But I really want to do something about carrying around the rrules. Maybe I should look into Memoization?

@yebai
Copy link
Member

yebai commented Aug 22, 2024

Well turns out it is this PR that is breaking the interface (purely to carry around the rrules), but the master branch is still using the old interface, which is what is causing the conflicts.

It would be good to avoid heuristics like Memoization and pass cached rules explicitly. Can we upgrade everything to this PR's interface?

@Red-Portal
Copy link
Member Author

If Memoization works for this, we can go back to the old interface and contain the memoization stuff in Tapir's extension. I personally think this is the way to go if Memoization works since passing around rrules is just way too complicated and only necessary because of Tapir.

@yebai
Copy link
Member

yebai commented Aug 22, 2024

I don’t have good experience with memorisation tricks; they lead to subtle bugs and are often hard to reason about. Maybe @willtebbutt have thought about this?

@willtebbutt
Copy link
Member

I would also be wary of using Memoization here. At least one issue that could result in not-entirely-obvious cache invalidation is method redefinition. The global method table handles this well obviously, but I'm not sure it's entirely trivial to do yourself (hence the fact that Tapir.jl doesn't yet do it).

Re Tapir.jl being the only backend to require a preparation step: is it not the case that ReverseDiff.jl requires a cache when operating in compiled mode?

@Red-Portal
Copy link
Member Author

Hi @willtebbutt !

At least one issue that could result in not-entirely-obvious cache invalidation is method redefinition.

Would this happen? We only use global functions now and the only things that change are auxiliary inputs. So if we are thinking of the same thing, I think it should be okay.

Re Tapir.jl being the only backend to require a preparation step: is it not the case that ReverseDiff.jl requires a cache when operating in compiled mode?

I don't think we'll be able to use compiled tapes because they are incompatible with changing auxiliary inputs.

@yebai
Copy link
Member

yebai commented Aug 25, 2024

There are a few details to consider carefully. I suggest staying away from memorization heuristics; the cons often outweigh the minor efforts that we could save (from using memorization) when it comes to finding a proper solution.

I hope that @willtebbutt can take over from here to ensure a robust AD interface for Tapir and other packages.

@willtebbutt
Copy link
Member

Would this happen? We only use global functions now and the only things that change are auxiliary inputs. So if we are thinking of the same thing, I think it should be okay.

This can definitely happen -- Julia's world age system is used to track this, so if you were to do caching you would definitely need to take it into account.

@Red-Portal
Copy link
Member Author

Hi @willtebbutt @yebai , okay after seeing the discussion, I guess carrying around rrules is the best we can do for now. Then I agree with this until DifferentiationInferface catches up or similar.

@yebai
Copy link
Member

yebai commented Sep 2, 2024

Closed in favour of #86

@yebai yebai closed this Sep 2, 2024
@Red-Portal Red-Portal deleted the tapir branch September 10, 2024 04:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants