
[tfjs-core] fix gather gradient when batchDims is 1 #7942

Merged · 7 commits · Sep 12, 2023

Conversation

paradite
Contributor

@paradite paradite commented Sep 5, 2023

This PR fixes #7494, where an error is encountered when computing the gradient of tf.gather with batchDims=1.

I am opening a PR with a naive, unoptimized fix first, to see if the CI passes and to gather feedback on the direction of the fix.

If the CI passes, I will work on optimizing the logic if needed.

I am not really familiar with the math behind the original implementation in derX of gatherGradConfig, and I couldn't get the fix to work within the derX function. I am open to suggestions for a better, more optimized way to fix the issue.
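For readers unfamiliar with what batchDims=1 means, here is a plain-array sketch of the forward pass (no tfjs; `batchedGather` is a hypothetical helper, not the library API): each batch row selects from its own slice of x using its own row of indices.

```typescript
// Plain-array sketch of gather with batchDims = 1: batch b of the output
// is gathered from batch b of x using batch b of the indices.
function batchedGather(x: number[][], indices: number[][]): number[][] {
  return x.map((row, b) => indices[b].map((idx) => row[idx]));
}

console.log(batchedGather([[10, 20, 30], [40, 50, 60]], [[2, 0], [1, 1]]));
// [[30, 10], [50, 50]]
```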

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.

@paradite paradite changed the title [tfjs-core][draft] fix gather gradient when batchDims is 1 [tfjs-core] fix gather gradient when batchDims is 1 Sep 5, 2023
@mattsoulanille
Member

/gcbrun

@mattsoulanille
Member

Thanks for the contribution! It looks like you're splitting the batched tensor, applying the gradient function to each batch, and then stacking them together again. This seems reasonable to me.

The original math for derX is also not really clear to me. For y = tf.gather(x, indices, axis, batchDim), I'd expect dy/dx to just copy the selected values specified by the indices variable from the dy variable and be zero everywhere else. However, it seems to be using unsortedSegmentSum. I think this is to account for repeated indices. There's also some complicated transpose logic to make the input work for unsortedSegmentSum. It might be possible to get the indices of these two ops to work correctly for arbitrary batch dimensions, but I'm not sure. I'm fine with your approach, and we can revisit it if we need better performance in the future.
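To illustrate why a summing scatter (what unsortedSegmentSum provides) is needed rather than a plain copy, here is a minimal plain-array sketch of an unbatched 1-D gather gradient (`gatherGrad1d` is a hypothetical helper, not the tfjs implementation): when an index repeats, both output positions were read from the same input element, so their gradients must accumulate.

```typescript
// Gradient of y = gather(x, indices) in 1-D: scatter-add dy back into
// x's shape. Repeated indices accumulate their contributions.
function gatherGrad1d(dy: number[], indices: number[], xLength: number): number[] {
  const dx = new Array<number>(xLength).fill(0);
  indices.forEach((idx, i) => {
    dx[idx] += dy[i]; // "+=" is the crucial part: a plain copy would drop one contribution
  });
  return dx;
}

// y = gather(x, [1, 1, 3]) with dy = [10, 20, 30]:
// index 1 appears twice, so dx[1] = 10 + 20.
console.log(gatherGrad1d([10, 20, 30], [1, 1, 3], 4));
// [0, 30, 0, 30]
```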

There's probably a way to apply your approach for generic batchDims, but I'm not going to block this PR because of that.
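The split/apply/stack approach described above can be sketched with plain arrays (hypothetical helpers, not the actual tfjs code): split dy and indices along the batch dimension, compute the unbatched gradient for each batch, and stack the per-batch results.

```typescript
// Unbatched 1-D gather gradient: scatter-add dy into x's shape.
function gatherGradUnbatched(dy: number[], indices: number[], xLen: number): number[] {
  const dx = new Array<number>(xLen).fill(0);
  indices.forEach((idx, i) => { dx[idx] += dy[i]; });
  return dx;
}

// batchDims = 1: apply the unbatched gradient per batch, then stack.
function gatherGradBatched(dy: number[][], indices: number[][], xLen: number): number[][] {
  return dy.map((dyBatch, b) => gatherGradUnbatched(dyBatch, indices[b], xLen));
}

// Two batches over an x of length 3; batch 1 has a repeated index.
console.log(gatherGradBatched([[1, 2], [3, 4]], [[0, 2], [1, 1]], 3));
// [[1, 0, 2], [0, 7, 0]]
```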

Also, it looks like CI is failing due to code linting. I'll push a patch to fix that.

@mattsoulanille
Member

/gcbrun

Member

@mattsoulanille mattsoulanille left a comment


LGTM

@paradite
Contributor Author

Thank you @mattsoulanille for the help and review. I will take note of the lint issue for future PRs.

Hi @pyu10055, can I get your blessing and help to merge this?

@pyu10055 pyu10055 merged commit f44e224 into tensorflow:master Sep 12, 2023
Successfully merging this pull request may close these issues.

tf.gather causes reshape error during gradient computation