[TIR, Relay] improve bfloat16 support #2
Conversation
.gitignore (Outdated)
@@ -11,7 +11,10 @@ __pycache__/
.Python
env/
build/
build_debug/
Please don't change this. You can change it locally, but don't upstream it.
ok, I'll fix this.
include/tvm/tir/op.h (Outdated)
static const Op& op = Op::Get("tir." #OpName);       \
if (x.dtype().is_bfloat16()) {                        \
  DataType srcType = x.dtype();                       \
  DataType dstType(kDLFloat, 32, srcType.lanes());    \
Please align those trailing `\` continuations so they line up in a column.
ok
@@ -40,6 +40,8 @@
    "nn.conv3d_transpose",
    "nn.dense",
    "nn.batch_matmul",
    "nn.bias_add",
Not sure we can change this default list. Better to have a separate CPU list; otherwise you need to evaluate the impact on NV hardware.
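For illustration, a rough sketch of what a separate CPU list could look like instead of editing the shared default; the list contents and the `always_list_for` helper are hypothetical, not from this PR:

```python
# Hypothetical sketch: keep the default (GPU-evaluated) "always convert" list
# untouched and compose a CPU-specific list on top of it, so NV hardware is unaffected.
DEFAULT_ALWAYS_LIST = [
    "nn.conv3d_transpose",
    "nn.dense",
    "nn.batch_matmul",
]

# CPU-only additions live in their own list.
CPU_EXTRA_ALWAYS_LIST = ["nn.bias_add"]


def always_list_for(target_kind: str):
    """Return the op list to use for a given target kind (illustrative helper)."""
    if target_kind.startswith("llvm"):  # CPU targets
        return DEFAULT_ALWAYS_LIST + CPU_EXTRA_ALWAYS_LIST
    return DEFAULT_ALWAYS_LIST
```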
@@ -126,3 +155,4 @@ def test_fp16_conversion(target, dev):
    test_basic_build()
    test_fp16_build()
    test_fp16_conversion()
Do we need to add a test_bf16_conversion, like the fp16 one?
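For reference, a minimal sketch of what such a test might look like, mirroring test_fp16_conversion; it assumes ToMixedPrecision accepts "bfloat16" after this change, and the toy network is illustrative:

```python
# Hypothetical sketch of a bf16 counterpart to test_fp16_conversion.
import tvm
from tvm import relay


def test_bf16_conversion():
    # Build a tiny float32 graph.
    x = relay.var("x", shape=(1, 16), dtype="float32")
    w = relay.var("w", shape=(8, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

    # Run AMP targeting bfloat16 (assumes the pass accepts "bfloat16" after this PR).
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ToMixedPrecision("bfloat16")(mod)

    # The dense compute should now be in bfloat16.
    assert "bfloat16" in str(mod)
```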
LGTM!
Thanks for reviewing, this PR has been merged into the official repo.
* Revert "[skip ci] Revert "[ci] Default to n=2 for test parallelism (apache#12376)" (apache#12413)"

  This reverts commit 478b672.

* [ci] Default to n=2 for test parallelism

  This is attempt #2 of apache#12376, which was reverted in apache#12413. The changes in `plugin.py` should keep all the tests on the same node so sporadic failures don't happen due to scheduling.

  Co-authored-by: driazati <driazati@users.noreply.github.com>
Motivation:
We are enabling bfloat16 in BYOC-oneDNN along the path: [float32 graph] --> <AMP> --> [bfloat16 graph] --> <BYOC> --> [TVM + oneDNN module]. However, some passes, such as FoldConstant, could not handle bfloat16 before the improvements below.

Changes:
With these improvements, a float32 graph can now be converted to bfloat16 through AMP and then lowered to run inference in bfloat16 mode.
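A minimal sketch of that path, assuming AMP accepts "bfloat16" after this PR; the toy conv2d network and the llvm target are illustrative, not from the PR:

```python
import numpy as np
import tvm
from tvm import relay

# Toy float32 graph (illustrative).
x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
w = relay.const(np.random.rand(8, 3, 3, 3).astype("float32"))
y = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=8)
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# AMP: rewrite float32 ops into bfloat16 where the op lists allow it.
mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision("bfloat16")(mod)

# Passes such as FoldConstant must now handle bfloat16 constants.
mod = relay.transform.FoldConstant()(mod)

# Lower and build for a CPU target (illustrative).
lib = relay.build(mod, target="llvm")
```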
Tested Models (gluoncv):
As @AndrewZhaoLuo said at apache#8069
Pending:
The support for bfloat16 in BYOC-oneDNN is based on the multi-blocking layout transform and extensions to BYOC-oneDNN, and is still pending.