Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataType] Add bfloat16 #5601

Merged
merged 47 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
6ca0e30
add bf16
Menooker Apr 30, 2020
3fba684
add bf16 in DataType (py)
Menooker May 6, 2020
48e7e94
ndarray of bf16
Menooker May 7, 2020
17ef57b
do not cast back for compare op
Menooker May 7, 2020
96f3019
const gen
Menooker May 7, 2020
4aeff41
more precise
Menooker May 7, 2020
c551a3d
update test
Menooker May 8, 2020
0510af9
enable vectorization
Menooker May 9, 2020
3c5c0f4
correct vectorize
Menooker May 9, 2020
c978b9e
linter changes
Menooker May 15, 2020
ef6f410
linter
Menooker May 15, 2020
92d014a
linter
Menooker May 15, 2020
e23f33c
linter
Menooker May 15, 2020
a245d41
Update bf16_legalize.cc
Menooker May 15, 2020
17c7084
Update bf16_legalize.cc
Menooker May 15, 2020
d51bf9b
Update bf16_legalize.cc
Menooker May 15, 2020
cbb1e5b
Update transform.py
Menooker May 15, 2020
680ecce
fix
Menooker May 15, 2020
b4b9d42
Update test_target_codegen_llvm.py
Menooker May 16, 2020
a899ef7
Update test_target_codegen_llvm.py
Menooker May 16, 2020
3523f00
Update transform.py
Menooker May 16, 2020
bac7247
bf16 => bfloat16
Menooker May 20, 2020
7224acf
fix linter problem
Menooker May 20, 2020
b36f5f4
TIR legalize
Menooker May 24, 2020
92968d0
pass test
Menooker May 24, 2020
177f99a
linter
Menooker May 24, 2020
27c0ab2
linter
Menooker May 24, 2020
3dd2a71
linter
Menooker May 24, 2020
c546376
fix AttrStmtNode
Menooker May 24, 2020
01906c5
msvc compile
Menooker May 24, 2020
7d99adb
Merge branch 'master' into bf16
liangfu May 29, 2020
817f302
comments and notes
Menooker May 29, 2020
ae67413
linter
Menooker May 29, 2020
bf1b747
Code style, use kDLBfloat
Menooker Jun 1, 2020
b1c1951
format
Menooker Jun 1, 2020
86bcc16
update dlpack
Menooker Jun 1, 2020
7612f9d
change back nullptr typecode
Menooker Jun 5, 2020
02990b1
Merge branch 'master' of https://github.com/apache/incubator-tvm into…
Menooker Jun 5, 2020
1b85a00
remove python runtime type for bf16
Menooker Jun 5, 2020
c0cb1ef
fix code style of RoundToNearestEven
Menooker Jun 12, 2020
b8b5b4a
Merge branch 'master' of https://github.com/apache/incubator-tvm into…
Menooker Jun 12, 2020
a6341cb
merge newest master
Menooker Jun 12, 2020
656e3e4
format
Menooker Jun 12, 2020
25c811c
pylint on test
Menooker Jun 13, 2020
af5438a
Merge branch 'master' of https://github.com/apache/incubator-tvm into…
Menooker Jun 15, 2020
7a72ea9
make it run on newest master
Menooker Jun 15, 2020
318ddc9
type code changes etc.
Menooker Jun 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class DataType {
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
kCustomBegin = 129
};
/*! \brief default constructor */
Expand All @@ -72,6 +73,9 @@ class DataType {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
if (code == kBFloat) {
CHECK_EQ(bits, 16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is understandable that right now we only support bf16, but my concern is that "should we put the check here"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your concern. Any suggestions for the location where we put this check? Thanks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen This is just a nitpick. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us leave it as it is for now, we can come back to it later

}
}
/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
Expand All @@ -89,6 +93,8 @@ class DataType {
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
/*! \return whether type is an int type. */
bool is_int() const { return code() == DataType::kInt; }
/*! \return whether type is an uint type. */
Expand Down Expand Up @@ -283,6 +289,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "float";
case DataType::kHandle:
return "handle";
case kDLBfloat:
return "bfloat";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
Expand Down Expand Up @@ -349,6 +357,9 @@ inline DLDataType String2DLDataType(std::string s) {
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "bfloat") {
t.code = DataType::kBFloat;
scan = s.c_str() + 6;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
}
}
if (t.is_float()) return FloatImm(t, static_cast<double>(value));
if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ TVM_DLL Pass CombineContextCall();
*/
TVM_DLL Pass NarrowDataType(int target_bits);

/*!
* \brief Legalize bf16 typed Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
* \return The pass.
*/
TVM_DLL Pass BF16Legalize();

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DataTypeCode(object):
UINT = 1
FLOAT = 2
HANDLE = 3
BFLOAT = 4


class DataType(ctypes.Structure):
Expand All @@ -65,7 +66,8 @@ class DataType(ctypes.Structure):
DataTypeCode.INT : 'int',
DataTypeCode.UINT : 'uint',
DataTypeCode.FLOAT : 'float',
DataTypeCode.HANDLE : 'handle'
DataTypeCode.HANDLE : 'handle',
DataTypeCode.BFLOAT : 'bfloat'
}
def __init__(self, type_str):
super(DataType, self).__init__()
Expand Down Expand Up @@ -96,6 +98,9 @@ def __init__(self, type_str):
self.type_code = DataTypeCode.HANDLE
bits = 64
head = ""
elif head.startswith("bfloat"):
self.type_code = DataTypeCode.BFLOAT
head = head[6:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
Expand Down
1 change: 1 addition & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def lower(sch,
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
]
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,56 @@ def RemoveNoOp():
"""
return _ffi_api.RemoveNoOp()

def BF16Legalize():
"""Legalize bf16 typed Ops.
Runs BF16Promote, BF16CastElimination and BF16TypeLowering

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16Legalize()

def BF16Promote():
"""Promote bf16 to fp32. Add a cast to fp32
before Ops, then add a cast back to bf16.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16Promote()

def BF16CastElimination():
"""Eliminate verbose casting between fp32 and bf16
Checks if the AST has the pattern:
castto32(castto16(some_fp32_op(...)))
The verbose casting is generated by BF16Promote for multiple
bf16 Ops in a row. e.g.:
X[i] + Y[i] + T[i] =>
bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
After this pass:
bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16CastElimination()

def BF16TypeLowering():
"""Replace all bf16 type with uint16. Also lower the casting
between fp32 and bf16

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16TypeLowering()

def RewriteUnsafeSelect():
"""Detect and rewrite unsafe select that contains memory access.
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
// Phase 1
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LoopPartition());
Expand Down
Loading