
[Bug-fix][XLA:CPU][oneDNN] Fix BINARY_ADD fusion to Dot #13301

Closed · wants to merge 1 commit

Conversation


@mdfaijul (Contributor) commented Jun 1, 2024

This PR fixes a bug reported for JAX (#13054)

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 1, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 1, 2024
@NaiyerRizz NaiyerRizz self-requested a review June 3, 2024 04:32

@penpornk penpornk left a comment


Thank you for the fix! Could you please explain what is causing the issue, and how this fix addresses it? Is it because of rank mismatch + wrong auto broadcasting or something?


ENTRY main {
constant.2 = f32[] constant(1e-06)
broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={}
Member

I don't think the size needs to be this big to reproduce the failure. Would 10 work?

Contributor

The issue is not reproducible with a smaller size.

subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13)
constant.4 = f32[] constant(0)
broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={}
dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
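For readers less familiar with HLO, the quoted fragment can be rendered roughly in NumPy as follows. The producers of broadcast.8 and broadcast.13 are not shown in the snippet, so the zero and 1e-6 stand-in values here are purely illustrative:

```python
import numpy as np

n = 1_000_000  # the reporter notes the failure does not reproduce at smaller sizes

# Stand-ins for broadcast.8 and broadcast.13 (their producers are not quoted above).
broadcast_8 = np.zeros((n, 3), dtype=np.float32)
broadcast_13 = np.full((n, 3), 1e-6, dtype=np.float32)

subtract_14 = broadcast_8 - broadcast_13  # subtract.14 = f32[n,3]

# constant.4 = 0, broadcast to a [3,3] matrix (broadcast.5).
broadcast_5 = np.zeros((3, 3), dtype=np.float32)

# dot.15 contracts dim 1 of subtract.14 against dim 0 of broadcast.5,
# i.e. an ordinary matrix multiply yielding f32[n,3].
dot_15 = subtract_14 @ broadcast_5
```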
Member

Can we reduce this to just the ops necessary to reproduce the failure? I don't think all the dots are needed.

Contributor

The bug only seems to manifest with this particular case.


@hawkinsp (Member) commented Jun 3, 2024

I'd like to get this fix in within the next day or so, so that I can incorporate it in the next JAX release, please.

@kanvi-nervana (Contributor) commented:

> Thank you for the fix! Could you please explain what is causing the issue, and how this fix addresses it? Is it because of rank mismatch + wrong auto broadcasting or something?

oneDNN expects a Matmul followed by a Bias-Add followed by a Binary-Add. Here, however, the Matmul is followed by a Binary-Add and then a Bias-Add, which oneDNN does not support. The fix extends the dimensions of the Bias-Add addend so that it becomes a supported Binary-Add, as seen below.

[Screenshot, 2024-06-03: graph before and after the fix]
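The commit summarizes the rewrite as "Make addend rank same to dot." A minimal NumPy sketch of that idea (illustrative only, not the actual XLA pass; the helper name is hypothetical):

```python
import numpy as np

def rank_extend_addend(dot_out, addend):
    """Pad the addend with leading unit dimensions until its rank matches
    the dot output, turning a bias-style add into a rank-matched binary
    add of the kind oneDNN can fuse. (Hypothetical helper, illustration only.)"""
    missing = dot_out.ndim - addend.ndim
    return addend.reshape((1,) * missing + addend.shape)

dot_out = np.zeros((4, 3), dtype=np.float32)  # stand-in for the dot result
bias = np.ones((3,), dtype=np.float32)        # rank-1 addend
extended = rank_extend_addend(dot_out, bias)  # now shape (1, 3), same rank as dot_out
result = dot_out + extended                   # broadcasts exactly like the original bias add
```

Because the padded dimensions are all of size 1, NumPy-style broadcasting makes the rank-extended add produce the same values as the original bias add; only the rank (and hence the fusion pattern oneDNN sees) changes.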

@penpornk penpornk left a comment


Thank you very much for the clarifications!

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jun 4, 2024
Imported from GitHub PR openxla/xla#13301

This PR fixes a bug reported for JAX (openxla/xla#13054)
Copybara import of the project:

--
47d5bde8eab607d0fe9b60c6fd82d95365c8169f by mdfaijul <md.faijul.amin@intel.com>:

Make addend rank same to dot.

Merging this change closes #13301

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13301 from Intel-tensorflow:amin/bug-fix-jax 47d5bde8eab607d0fe9b60c6fd82d95365c8169f
PiperOrigin-RevId: 640081553
copybara-service bot pushed a commit that referenced this pull request Jun 4, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=#13301 from Intel-tensorflow:amin/bug-fix-jax 47d5bde
PiperOrigin-RevId: 638276915
@copybara-service copybara-service bot closed this in 7d12719 Jun 4, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jun 4, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13301 from Intel-tensorflow:amin/bug-fix-jax 47d5bde8eab607d0fe9b60c6fd82d95365c8169f
PiperOrigin-RevId: 638276915
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jun 4, 2024
Imported from GitHub PR openxla/xla#13301

This PR fixes a bug reported for JAX (openxla/xla#13054)
Copybara import of the project:

--
47d5bde8eab607d0fe9b60c6fd82d95365c8169f by mdfaijul <md.faijul.amin@intel.com>:

Make addend rank same to dot.

Merging this change closes #13301

PiperOrigin-RevId: 640094871
5 participants