Improve enzyme gradient ops removal in while op #167

Merged (8 commits) on Nov 13, 2024

Conversation

Pangoraw (Collaborator)

  • inclusive ranges.
  • floating point ranges.
  • complete enzyme.set/get removal by zeroing gradients.

```cpp
if (!isExact)
  return std::nullopt;

return numIters_i.getSExtValue() + inclusive;
```
Member

I don't think this is mathematically correct if step != 1.

e.g. `2 * i < 3` is satisfied by i = 0 and i = 1, and so is `2 * i <= 3`, so both comparisons give the same trip count.
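To make the objection concrete, here is a small brute-force sketch (hypothetical helper, not from the PR) showing that the inclusive trip count is not always the exclusive count plus one once step != 1, which is why `numIters + inclusive` only holds for step == 1:

```python
# Brute-force trip counts for "for (i = start; i CMP limit; i += step)".
def trip_count(start, limit, step, inclusive):
    count, i = 0, start
    while (i <= limit) if inclusive else (i < limit):
        count += 1
        i += step
    return count

# step == 1: the "add 1 for inclusive" rule holds.
assert trip_count(0, 3, 1, False) == 3
assert trip_count(0, 3, 1, True) == 4

# step == 2, limit == 3: both comparisons run 2 iterations (i = 0, 2),
# so adding 1 for the inclusive case would overcount.
assert trip_count(0, 3, 2, False) == 2
assert trip_count(0, 3, 2, True) == 2
```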

Member

I think the better way to make this change is to add 1 to the limit, then reuse the `<` code path:

```mlir
func.func @main(%arg0: tensor<i64>) -> tensor<i64> {
  %c = stablehlo.constant dense<10> : tensor<i64>
  %c_0 = stablehlo.constant dense<1> : tensor<i64>
  %0:2 = stablehlo.while(%iterArg = %c_0, %iterArg_1 = %arg0) : tensor<i64>, tensor<i64>
```
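The suggested rewrite rests on the identity that, for integers, `i <= limit` is equivalent to `i < limit + 1` for any step (modulo overflow when limit is the type's maximum value). A hedged sketch of the idea, with hypothetical helper names:

```python
def trip_count_lt(start, limit, step):
    # Exclusive trip count: i in start, start+step, ... with i < limit.
    # Assumes a positive step.
    if step <= 0 or limit <= start:
        return 0
    return -(-(limit - start) // step)  # ceiling division

def trip_count_le(start, limit, step):
    # Inclusive count via the "limit + 1" rewrite: reuse the < logic.
    return trip_count_lt(start, limit + 1, step)

assert trip_count_le(0, 3, 2) == 2   # i = 0, 2
assert trip_count_le(0, 3, 1) == 4   # i = 0, 1, 2, 3
```

This handles step != 1 correctly, unlike adding 1 to the exclusive count after the fact.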
Member

So floats are weird (their addition isn't associative, so rounding accumulates), and I'm not confident this will always work.

Instead, could we have an optimization that rewrites while comparisons on floats into integer ones, and then we only need the integer AD special case?
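A minimal sketch of that proposed canonicalization (hypothetical helpers, not the actual pass): replace the float induction variable with an integer counter whose trip count is computed once, recomputing the float value as `start + i * step` in the body instead of accumulating additions:

```python
import math

def float_loop_values(start, limit, step):
    # Original form: float accumulation, rounding error can change
    # the iteration count near the limit.
    xs, x = [], start
    while x < limit:
        xs.append(x)
        x += step
    return xs

def int_loop_values(start, limit, step):
    # Canonicalized form: integer counter, float recomputed per iteration.
    # Assumes a positive step.
    n = max(0, math.ceil((limit - start) / step))
    return [start + i * step for i in range(n)]

# With exactly representable values the two loops agree.
assert float_loop_values(0.0, 1.0, 0.25) == int_loop_values(0.0, 1.0, 0.25)
```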

Pangoraw (Collaborator, Author)

I agree. I think most of the changes in this PR should be reverted so that Enzyme-JAX's canonical representation of for loops covers only the [0,N[ range, and inclusive or floating-point ranges are canonicalized on top of that. In the meantime, I will make the Reactant PR emit [0,N[ for loops as well. Always having [0,N[ is handy because a cache push is then a dynamic_update_slice at the loop index and a cache pop is a dynamic_slice at N - i.
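The cache indexing argument can be sketched in plain Python (a toy stand-in for dynamic_update_slice / dynamic_slice, assuming 0-based indexing so the reverse read lands at N - 1 - i):

```python
N = 4
cache = [None] * N

# Forward pass over [0, N): iteration i pushes into slot i
# (the dynamic_update_slice in the actual IR).
for i in range(N):
    cache[i] = i * i  # stand-in for the cached intermediate value

# Reverse pass: iteration i pops slot N - 1 - i
# (the dynamic_slice in the actual IR).
popped = [cache[N - 1 - i] for i in range(N)]
assert popped == [9, 4, 1, 0]
```

With inclusive or float ranges, the pop index would need extra trip-count arithmetic, which is what makes [0,N[ the convenient canonical form.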

@Pangoraw changed the title from "More whileops things" to "Improve enzyme gradient ops removal in while op" on Nov 11, 2024
@wsmoses merged commit 07dc4d4 into EnzymeAD:main on Nov 13, 2024
5 of 9 checks passed