Improve enzyme gradient ops removal in while op #167
Pangoraw commented Nov 10, 2024
- inclusive ranges.
- floating point ranges.
- complete enzyme.set/get removal by zeroing gradients.
    !isExact)
  return std::nullopt;

return numIters_i.getSExtValue() + inclusive;
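I don't think this is mathematically correct if step != 1.
For example, 2 * i < 3 holds only for i = 0 and i = 1, and the same is true of 2 * i <= 3, so the inclusive and exclusive bounds give the same iteration count here.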
I think the better way to make this change is to add 1 to the limit and then reuse the < code path.
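The suggestion above can be sketched as follows. This is only an illustration under stated assumptions, not the code from the PR: `tripCount` is a hypothetical helper, the bounds are plain 64-bit integers rather than `APInt` values, and only positive steps are handled. An inclusive bound is rewritten to an exclusive one by adding 1 to the limit, and the count is then a ceiling division.

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical helper (not from the PR): number of iterations of
//   for (i = start; i < limit; i += step)      when inclusive == false
//   for (i = start; i <= limit; i += step)     when inclusive == true
// Returns std::nullopt when the count cannot be computed safely.
std::optional<int64_t> tripCount(int64_t start, int64_t limit, int64_t step,
                                 bool inclusive) {
  if (step <= 0)
    return std::nullopt;  // this sketch only handles positive steps

  if (inclusive) {
    // Turn `i <= limit` into `i < limit + 1`, guarding against overflow.
    if (limit == std::numeric_limits<int64_t>::max())
      return std::nullopt;
    limit += 1;
  }

  if (limit <= start)
    return 0;  // the loop body never runs

  // limit > start, so the difference is exact in unsigned arithmetic.
  uint64_t span = static_cast<uint64_t>(limit) - static_cast<uint64_t>(start);
  uint64_t ustep = static_cast<uint64_t>(step);

  // Ceiling division without the overflow risk of (span + ustep - 1).
  uint64_t count = span / ustep + (span % ustep != 0 ? 1 : 0);

  if (count > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
    return std::nullopt;
  return static_cast<int64_t>(count);
}
```

With start 0, limit 3 and step 2, this gives 2 iterations for both `<` and `<=` (i = 0 and i = 2), while with limit 4 the inclusive loop gains one iteration (i = 0, 2, 4).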
test/lit_tests/unroll2.mlir
func.func @main(%arg0: tensor<i64>) -> tensor<i64> {
  %c = stablehlo.constant dense<10> : tensor<i64>
  %c_0 = stablehlo.constant dense<1> : tensor<i64>
  %0:2 = stablehlo.while(%iterArg = %c_0, %iterArg_1 = %arg0) : tensor<i64>, tensor<i64>
Floats are weird: floating-point arithmetic is not necessarily associative and rounding error accumulates, so I'm not confident this will always work.
Instead, could we have an optimization that rewrites while-loop comparisons on floats into comparisons on ints, so that we only need the int special case in AD?
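To illustrate the concern, here is a small sketch that compares the iteration count obtained by actually running a floating-point loop with the count from a closed-form formula. Both function names are hypothetical and neither is taken from the PR; the sketch assumes IEEE-754 doubles and a positive step.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helper: runs the loop literally,
//   for (x = start; x < limit; x += step)
// so rounding error accumulates in x exactly as it would at run time.
int64_t simulatedTripCount(double start, double limit, double step) {
  int64_t n = 0;
  for (double x = start; x < limit; x += step)
    ++n;
  return n;
}

// Hypothetical helper: closed-form count ceil((limit - start) / step),
// clamped at zero, as a compiler might try to compute it statically.
int64_t closedFormTripCount(double start, double limit, double step) {
  double n = std::ceil((limit - start) / step);
  return static_cast<int64_t>(std::max(n, 0.0));
}
```

For start 0.0, limit 1.0 and step 0.1, the accumulated value after ten additions is slightly below 1.0 in double precision, so the literal loop runs an eleventh time while the closed form reports 10. When the step is exactly representable (for example 1.0), the two agree, which is why a counter-based integer loop is the safer canonical form.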
I agree. I think most of the changes in this PR should be reverted so that Enzyme-JAX's canonical representation of for loops covers only the [0,N[ range; inclusive or floating-point ranges can then be canonicalized into that form. In the meantime, I will make the Reactant PR emit [0,N[ for loops as well. Always having [0,N[ is handy because a cache push becomes a dynamic_update_slice at the loop index i, and a cache pop becomes a dynamic_slice at N - i.
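The cache pattern described above can be sketched in plain C++, with a `std::vector` standing in for the cache tensor. Everything here is an assumption made for illustration: `gradRepeatedSquare` is a hypothetical function, the indexed store and load only mimic what dynamic_update_slice and dynamic_slice would do, and the reverse pass uses a 0-based counter j, so the "N - i" index is written as N - 1 - j.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical example: differentiate n applications of x = x * x
// with respect to the initial value x0, using a [0, n) loop.
double gradRepeatedSquare(double x0, std::size_t n) {
  std::vector<double> cache(n);  // one slot per iteration

  // Forward pass: "push" the input of each iteration at the loop index
  // (the analogue of dynamic_update_slice at index i).
  double x = x0;
  for (std::size_t i = 0; i < n; ++i) {
    cache[i] = x;
    x = x * x;
  }

  // Reverse pass: "pop" in the opposite order
  // (the analogue of dynamic_slice at index n - 1 - j).
  double grad = 1.0;
  for (std::size_t j = 0; j < n; ++j) {
    double xi = cache[n - 1 - j];
    grad *= 2.0 * xi;  // d(x * x)/dx = 2 * x
  }
  return grad;
}
```

Because both loops run over the same [0, n) range, the cache size and both index expressions are known from n alone, with no extra bookkeeping for start offsets, steps or inclusive bounds.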
This reverts commit 3cc8e73.