[operator] Integrate oneDNN layer normalization implementation #19562
Conversation
Hey @bartekkuncer, thanks for submitting the PR.
CI supported jobs: [windows-gpu, sanity, centos-gpu, miscellaneous, website, unix-gpu, centos-cpu, clang, windows-cpu, unix-cpu, edge]
Force-pushed from 439e314 to c4b6bce
@bartekkuncer could you add more description to the PR to avoid confusion for the reviewers?
@pengzhao-intel I was planning to add it as soon as I fix the tests :)
Hi, can we compare with #19601?
@kpuatamazon Sorry for the late response, I was waiting for the layer norm optimization in oneDNN. Below are the results I got using Marian and my oneDNN implementation. It looks like oneDNN is faster in most cases. Can you tell me what CPU you are using?
I built mxnet using:
I've been using a c5.12xlarge. We should at least reshape to two dimensions, with the axis preserved and everything else multiplied together; the problem is identical for e.g. 100x28x10x10x10 and 280000x10 (sketched below). Also, those are some really small channels to layer normalize over. I also feel the optimal assembly implementation would benefit from a different ordering of the input tensor to allow for pure vertical adds, whereas layer normalization is currently set up for horizontal adds. I can certainly see how a JIT would do better at e.g. 1000x3, where multiple problems share the same vector, but oddly that's where Marian is doing better.
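A minimal sketch of the collapse being suggested here; the helper name is hypothetical and this is not MXNet code:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: collapse an N-D LayerNorm problem over the last axis
// into the equivalent 2-D problem by multiplying all leading dimensions.
std::vector<int64_t> CollapseTo2D(const std::vector<int64_t>& shape) {
  const int64_t rows = std::accumulate(
      shape.begin(), shape.end() - 1, int64_t{1}, std::multiplies<int64_t>());
  return {rows, shape.back()};  // {100, 28, 10, 10, 10} -> {280000, 10}
}
```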
Speaking up as a 'customer' of LayerNorm here: Sockeye (and its Transformer models) cares about smaller matrix sizes for LayerNorm, i.e. typically in ranges around x512.
I made the sizing more systematic. AVX512 marks runs using AVX512 instructions; Inverse means the 1.f/std is computed in advance rather than dividing by std in the loop (sketched below). Overall, the Marian implementation seems to win on smaller problem sizes, including the x512 sizes from @fhieber, but lose on larger problem sizes. Of course there are edge cases when the width is not a multiple of 16, and gcc is testing for those edge cases every time, so I can see how that could be optimized.
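A generic illustration of the Inverse variant (not Marian's or oneDNN's actual kernel):

```cpp
#include <cstddef>

// Hoist the division out of the inner loop by multiplying with a
// precomputed 1/std instead of dividing by std per element.
void NormalizeRow(float* x, std::size_t n, float mean, float stddev) {
  const float inv_std = 1.0f / stddev;  // one divide per row
  for (std::size_t i = 0; i < n; ++i)
    x[i] = (x[i] - mean) * inv_std;     // multiply is far cheaper than divide
}
```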
@kpuatamazon which version of oneDNN did you use for the benchmark? There is a change boosting the performance of layer normalization which is going to be included in the oneDNN v2.1 release.
I used whatever was in your pull request: c4b6bce
Force-pushed from 4fe9420 to 6d0428c
Jenkins CI successfully triggered: [centos-gpu]
@mxnet-bot run ci [website] |
Jenkins CI successfully triggered: [website]
…e#19562)
* [operator] Integrate oneDNN layer normalization implementation
* change sizeof(float) to mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype())
* remove eps from key and unify layernorm_fwd_t/mkldnn::layer_normalization_forward
* add author
Description
The change integrates oneDNN's implementation of forward and backward propagation of layer normalization for axis == -1 (the default case, i.e. the last axis).
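For context, creating and running the forward primitive with the oneDNN v2.x C++ API looks roughly like the minimal sketch below. This is a standalone illustration, not the PR's actual code; the function name, epsilon, and pre-created memory objects are placeholders, and error handling is omitted:

```cpp
#include <unordered_map>
#include "dnnl.hpp"

using namespace dnnl;

void RunLayerNormFwd(engine& eng, stream& strm, memory& src, memory& dst,
                     memory& scale_shift, memory& mean, memory& variance) {
  const float epsilon = 1e-5f;
  auto d = layer_normalization_forward::desc(
      prop_kind::forward_training, src.get_desc(), epsilon,
      normalization_flags::use_scale_shift);  // gamma/beta packed together in v2.x
  auto pd = layer_normalization_forward::primitive_desc(d, eng);
  layer_normalization_forward(pd).execute(
      strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
             {DNNL_ARG_SCALE_SHIFT, scale_shift},
             {DNNL_ARG_MEAN, mean}, {DNNL_ARG_VARIANCE, variance}});
  strm.wait();  // block until the primitive finishes
}
```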
Comments
As oneDNN's LayerNorm primitive does not support an axis parameter (https://oneapi-src.github.io/oneDNN/dev_guide_layer_normalization.html), I had to adjust the input tensors in MXNet before sending them to oneDNN to make it work with axis != -1. I tried two approaches:
Both approaches turned out to be significantly slower than the current MXNet implementation; a generic illustration of the kind of axis shuffling involved is sketched below.
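A hypothetical helper (not the PR's code) that builds the permutation moving the normalized axis to the last position, so the data matches what oneDNN expects; applying the inverse permutation afterwards restores the layout:

```cpp
#include <vector>

std::vector<int> AxisToLastPerm(int ndim, int axis) {
  if (axis < 0) axis += ndim;            // support negative axis values
  std::vector<int> perm;
  perm.reserve(ndim);
  for (int i = 0; i < ndim; ++i)
    if (i != axis) perm.push_back(i);    // keep the other axes in order...
  perm.push_back(axis);                  // ...and put the normalized axis last
  return perm;                           // e.g. ndim = 4, axis = 1 -> {0, 2, 3, 1}
}
```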
oneDNN's backward propagation is significantly faster than the current MXNet implementation. The forward implementation has performance similar to MXNet's generic version: depending on the shape, sometimes Marian is faster and sometimes oneDNN is. As the difference in performance is significant in some of these cases, I introduced a simple heuristic (based on a huge amount of benchmarking) to decide whether layer normalization should be computed by oneDNN:
This function can be found in the mkldnn_layer_norm.cc file.
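Illustrative only: the real heuristic lives in mkldnn_layer_norm.cc and was tuned by benchmarking; the threshold below is made up. Per the benchmarks above, a check of this shape would route small widths to the native (Marian-style) kernel and larger ones to oneDNN:

```cpp
#include <cstdint>
#include <vector>

bool UseOneDNNLayerNorm(const std::vector<int64_t>& shape, int axis) {
  const int ndim = static_cast<int>(shape.size());
  if (axis != -1 && axis != ndim - 1)
    return false;            // the oneDNN path handles only the last axis
  const int64_t width = shape.back();
  return width >= 1024;      // hypothetical cutoff where oneDNN starts to win
}
```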
Most recent performance numbers
ln_opperf1908clx.xlsx