-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DataType] Add bfloat16 #5601
[DataType] Add bfloat16 #5601
Changes from 39 commits
6ca0e30
3fba684
48e7e94
17ef57b
96f3019
4aeff41
c551a3d
0510af9
3c5c0f4
c978b9e
ef6f410
92d014a
e23f33c
a245d41
17c7084
d51bf9b
cbb1e5b
680ecce
b4b9d42
a899ef7
3523f00
bac7247
7224acf
b36f5f4
92968d0
177f99a
27c0ab2
3dd2a71
c546376
01906c5
7d99adb
817f302
ae67413
bf1b747
b1c1951
86bcc16
7612f9d
02990b1
1b85a00
c0cb1ef
b8b5b4a
a6341cb
656e3e4
25c811c
af5438a
7a72ea9
318ddc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,7 @@ class DataType { | |
kUInt = kDLUInt, | ||
kFloat = kDLFloat, | ||
kHandle = TVMArgTypeCode::kTVMOpaqueHandle, | ||
kBFloat = kDLBfloat, | ||
kCustomBegin = 129 | ||
}; | ||
/*! \brief default constructor */ | ||
|
@@ -65,6 +66,9 @@ class DataType { | |
data_.code = static_cast<uint8_t>(code); | ||
data_.bits = static_cast<uint8_t>(bits); | ||
data_.lanes = static_cast<uint16_t>(lanes); | ||
if (code == kBFloat) { | ||
CHECK_EQ(bits, 16); | ||
} | ||
} | ||
/*! \return The type code. */ | ||
int code() const { return static_cast<int>(data_.code); } | ||
|
@@ -82,6 +86,8 @@ class DataType { | |
bool is_float() const { return code() == DataType::kFloat; } | ||
/*! \return whether type is a float16 type. */ | ||
bool is_float16() const { return is_float() && bits() == 16; } | ||
/*! \return whether type is a bfloat16 type. */ | ||
bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } | ||
/*! \return whether type is an int type. */ | ||
bool is_int() const { return code() == DataType::kInt; } | ||
/*! \return whether type is an uint type. */ | ||
|
@@ -276,6 +282,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { | |
return "float"; | ||
case DataType::kHandle: | ||
return "handle"; | ||
case kDLBfloat: | ||
return "bfloat"; | ||
default: | ||
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); | ||
return ""; | ||
|
@@ -342,6 +350,9 @@ inline DLDataType String2DLDataType(std::string s) { | |
t.bits = 1; | ||
t.lanes = 1; | ||
return t; | ||
} else if (s.substr(0, 6) == "bfloat") { | ||
t.code = kDLBfloat; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the type code should directly use DataType::kBfloat There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with tq There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed |
||
scan = s.c_str() + 6; | ||
} else if (s.substr(0, 6) == "custom") { | ||
t.code = ParseCustomDatatype(s, &scan); | ||
} else { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ cdef enum TVMArgTypeCode: | |
kFloat = 2 | ||
kTVMOpaqueHandle = 3 | ||
kTVMNullptr = 4 | ||
kBFloat = 4 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, we want to keep the type code for the normal argument as it is, and not changin the FFI There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we remove this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed |
||
kTVMDataType = 5 | ||
kTVMContext = 6 | ||
kTVMDLTensorHandle = 7 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,6 +54,7 @@ class DataTypeCode(object): | |
UINT = 1 | ||
FLOAT = 2 | ||
HANDLE = 3 | ||
BFLOAT = 4 | ||
|
||
|
||
class DataType(ctypes.Structure): | ||
|
@@ -65,7 +66,8 @@ class DataType(ctypes.Structure): | |
DataTypeCode.INT : 'int', | ||
DataTypeCode.UINT : 'uint', | ||
DataTypeCode.FLOAT : 'float', | ||
DataTypeCode.HANDLE : 'handle' | ||
DataTypeCode.HANDLE : 'handle', | ||
DataTypeCode.BFLOAT : 'bfloat' | ||
} | ||
def __init__(self, type_str): | ||
super(DataType, self).__init__() | ||
|
@@ -96,6 +98,9 @@ def __init__(self, type_str): | |
self.type_code = DataTypeCode.HANDLE | ||
bits = 64 | ||
head = "" | ||
elif head.startswith("bfloat"): | ||
self.type_code = 4 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure if it is good to hard code here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Change to DataTypeCode. TVM refactors a lot (which is good). And when this PR was raised, all the type code here used hard codes. The other two issues you raised were also changed as required. |
||
head = head[6:] | ||
elif head.startswith("custom"): | ||
# pylint: disable=import-outside-toplevel | ||
import tvm.runtime._ffi_api | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is understandable that right now we only support bf16, but my concern is that "should we put the check here"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand your concern. Any suggestions for the location where we put this check? Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tqchen This is just a nitpick. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let us leave it as it is for now, we can come back to it later