[tfjs-core] fix gather gradient when batchDims is 1 #7942
Conversation
/gcbrun
Thanks for the contribution! It looks like you're splitting the batched tensor, applying the gradient function to each batch, and then stacking the results together again. This seems reasonable to me. The original math for `derX` […] There's probably a way to apply your approach for generic `batchDims`, but I'm not going to block this PR because of that. Also, it looks like CI is failing due to code linting; I'll push a patch to fix that.
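
For context, a minimal sketch of the split-apply-stack approach described above, not the actual patch. The helper `gradForBatchDimsZero` is a hypothetical stand-in for the existing `batchDims = 0` gradient math in `gatherGradConfig`'s `derX`, and the axis adjustment is illustrative:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Sketch: handle batchDims === 1 by unstacking along the batch axis,
// applying the unbatched gradient to each slice, and restacking.
// `gradForBatchDimsZero` is a hypothetical stand-in for the existing
// batchDims === 0 gradient logic in gatherGradConfig's derX.
function batchedGatherGrad(
    dy: tf.Tensor, x: tf.Tensor, indices: tf.Tensor, axis: number,
    gradForBatchDimsZero: (
        dy: tf.Tensor, x: tf.Tensor, indices: tf.Tensor,
        axis: number) => tf.Tensor): tf.Tensor {
  const dySlices = tf.unstack(dy, 0);
  const xSlices = tf.unstack(x, 0);
  const indexSlices = tf.unstack(indices, 0);
  // Within each batch slice the gather axis shifts down by one,
  // since the leading batch dimension has been removed.
  const gradSlices = dySlices.map(
      (dySlice, i) =>
          gradForBatchDimsZero(dySlice, xSlices[i], indexSlices[i], axis - 1));
  // Reassemble the per-batch gradients along the batch axis.
  return tf.stack(gradSlices, 0);
}
```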
/gcbrun
LGTM
Thank you @mattsoulanille for the help and review. I will take note of the lint issue for future PRs. Hi @pyu10055, can I get your blessing and help to merge this?
This PR fixes #7494, where an error is encountered when computing the gradient of `tf.gather` with `batchDims=1`.

I am opening a PR with a naive and unoptimized fix first, to see if the CI passes and to gather feedback on the direction of the fix. If the CI passes, I will work on optimizing the logic if needed.

I am not really familiar with the math behind the original implementation in `derX` of `gatherGradConfig`, and I couldn't get the fix to work within the `derX` function, so I am open to suggestions on a better and more optimized way to fix the issue.
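
For reference, a minimal reproduction of the failure this PR addresses; the shapes are illustrative, and before the fix the gradient call below threw when `batchDims = 1`:

```ts
import * as tf from '@tensorflow/tfjs-core';

// x: [batch, n, features]; indices: [batch, k] selecting along axis 1.
const x = tf.randomNormal([2, 4, 3]);
const indices = tf.tensor2d([[0, 2], [1, 3]], [2, 2], 'int32');

// Reduce to a scalar loss so tf.grad can backpropagate without an explicit dy.
const loss = (t: tf.Tensor) =>
    tf.gather(t, indices, /* axis */ 1, /* batchDims */ 1).sum();

const dx = tf.grad(loss)(x);  // errored before this fix when batchDims = 1
dx.print();
```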