-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Support 3D input for MKL-DNN softmax operator #14818
Conversation
@TaoLv thanks for the PR. Is there a test for the 1D softmax and could you show the performance of MKL-DNN primitive against original implementation? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM :)
Tests should be covered by def test_performance():
shapes = [(1024,), (96, 512), (96, 128, 128), (96, 256, 256), (1, 8, 1024, 1024)]
for sh in shapes:
a = mx.nd.random.uniform(shape=sh)
# warm up
b = mx.nd.softmax(a, axis=-1)
b.wait_to_read()
tic = time.time()
for i in range(1000):
b = mx.nd.softmax(a, axis=-1)
b.wait_to_read()
toc = time.time()
print("softmax %s, take %f ms" % (sh, (toc - tic)/1000*1000.0)) Some performance numbers as following:
This PR with MKL-DNN backend:
|
Pending on MKL-DNN update for better performance~ |
Fallback all softmax operations when axis != last dimension because they are not optimized in MKL-DNN. |
@TaoLv I have merged the MKL-DNN 0.19 and please rebase the code and see if everything is OK :) |
@TaoLv please rebase and retrigger again the CI issue is fixed now. |
Merging now :) Thanks for your contribution. |
* add 3d softmax * fix * handle req type * clean code * remove check * check axis * retrigger ci
Description
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments