[FEATURE] Integrate oneDNN binary primitive support for forward add, subtract, multiply, divide. #20713
Hey @agrabows, thanks for submitting the PR.
CI supported jobs: [centos-gpu, sanity, centos-cpu, windows-gpu, edge, website, miscellaneous, unix-cpu, clang, unix-gpu, windows-cpu]
@mxnet-bot run ci [windows-gpu]
Jenkins CI successfully triggered : [windows-gpu]
LGTM. Leaving it open for a bit for other reviewers to take a look at the revisions.
auto ndim_1 = inputs[1].shape().ndim();
return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
       inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
       dtype == mshadow::kFloat32 && dtype == inputs[1].dtype();
Please check whether oneDNN supports bfloat; if it does, please create a separate PR for it.
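For readers following the review, the support condition quoted above can be sketched as a standalone predicate. This is only an illustration under stated assumptions: `DType`, `TensorInfo`, and `SupportsBinaryBroadcast` are hypothetical stand-ins invented here, not MXNet or oneDNN types, and the snippet does not reflect how the PR structures the check beyond the four quoted lines.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for MXNet's dtype flags and NDArray metadata.
enum class DType { kFloat32, kFloat16, kInt32 };

struct TensorInfo {
  std::vector<std::size_t> shape;
  DType dtype;

  std::size_t ndim() const { return shape.size(); }

  // Number of elements: 0 if any dimension is 0, 1 for a 0-dim shape.
  std::size_t size() const {
    std::size_t n = 1;
    for (std::size_t d : shape) n *= d;
    return n;
  }
};

// Mirrors the reviewed condition: both inputs have 1 to 6 dimensions,
// neither is empty, and both are float32.
bool SupportsBinaryBroadcast(const TensorInfo& lhs, const TensorInfo& rhs) {
  const auto in_range = [](std::size_t ndim) { return ndim >= 1 && ndim <= 6; };
  return in_range(lhs.ndim()) && in_range(rhs.ndim()) &&
         lhs.size() != 0 && rhs.size() != 0 &&
         lhs.dtype == DType::kFloat32 && lhs.dtype == rhs.dtype;
}
```

With a predicate shaped like this, allowing another data type would amount to relaxing the final dtype comparison, which is why it is natural to handle that in a follow-up change.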
Description
Binary broadcast operators such as add, subtract, multiply, and divide are implemented in both the NDArray and NumPy modules, but until now had no oneDNN support. The goal of this task is to dispatch execution of these operators to the oneDNN binary primitive.
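As background on what "broadcast" means for these operators, here is a minimal sketch of the NumPy-style shape rule: shapes are aligned from the trailing dimension, missing leading dimensions count as 1, and each pair of dimensions must either match or contain a 1. The names `Shape` and `BroadcastShape` are assumptions made for this example and are not taken from the PR; the code requires C++17 for `std::optional`.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <vector>

using Shape = std::vector<std::size_t>;

// Returns the broadcast output shape, or std::nullopt if the two
// shapes are not broadcast-compatible.
std::optional<Shape> BroadcastShape(const Shape& a, const Shape& b) {
  const std::size_t ndim = std::max(a.size(), b.size());
  Shape out(ndim, 1);
  for (std::size_t i = 0; i < ndim; ++i) {
    // Walk from the trailing dimension; missing leading dims count as 1.
    const std::size_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const std::size_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) return std::nullopt;
    out[ndim - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}
```

For example, shapes (2, 3, 4) and (3, 1) broadcast to (2, 3, 4), while (2, 3) and (4, 3) are incompatible because 2 and 4 differ and neither is 1.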
Comments
A speedup was observed in all cases, up to ~350%.