[DataType] Add bfloat16 (apache#5601)
Menooker authored and Trevor Morris committed Jun 30, 2020
1 parent 675c585 commit 3d10f79
Showing 10 changed files with 676 additions and 2 deletions.
11 changes: 11 additions & 0 deletions include/tvm/runtime/data_type.h
@@ -53,6 +53,7 @@ class DataType {
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
kCustomBegin = 129
};
/*! \brief default constructor */
@@ -72,6 +73,9 @@
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
if (code == kBFloat) {
CHECK_EQ(bits, 16);
}
}
/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
@@ -89,6 +93,8 @@
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
/*! \return whether type is an int type. */
bool is_int() const { return code() == DataType::kInt; }
/*! \return whether type is an uint type. */
@@ -283,6 +289,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "float";
case DataType::kHandle:
return "handle";
case kDLBfloat:
return "bfloat";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
@@ -349,6 +357,9 @@ inline DLDataType String2DLDataType(std::string s) {
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "bfloat") {
t.code = DataType::kBFloat;
scan = s.c_str() + 6;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
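
Both conversion helpers are exercised whenever a dtype string crosses the FFI boundary. A small round-trip sketch, assuming a TVM build that includes this commit (the placeholder A is illustrative):

import tvm
from tvm import te

# String2DLDataType parses "bfloat16" on the way in;
# DLDataTypeCode2Str formats the type code as "bfloat" on the way back out.
A = te.placeholder((4,), dtype="bfloat16", name="A")
assert A.dtype == "bfloat16"
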
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
@@ -751,7 +751,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
}
}
if (t.is_float()) return FloatImm(t, static_cast<double>(value));
if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
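
With this one-line change, a floating-point literal that participates in bfloat16 arithmetic is materialized as a FloatImm instead of falling through to the custom-datatype path. A minimal sketch of how that surfaces from Python, again assuming a build with this commit (A and C are illustrative names):

import tvm
from tvm import te

# Mixing the literal 1.5 into bf16 arithmetic routes through MakeConstScalar,
# which now emits a FloatImm with dtype bfloat16.
A = te.placeholder((4,), dtype="bfloat16", name="A")
C = te.compute((4,), lambda i: A[i] + 1.5, name="C")
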
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
@@ -321,6 +321,13 @@ TVM_DLL Pass CombineContextCall();
*/
TVM_DLL Pass NarrowDataType(int target_bits);

/*!
* \brief Legalize bf16 typed Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
* \return The pass.
*/
TVM_DLL Pass BF16Legalize();

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
7 changes: 6 additions & 1 deletion python/tvm/_ffi/runtime_ctypes.py
@@ -54,6 +54,7 @@ class DataTypeCode(object):
UINT = 1
FLOAT = 2
HANDLE = 3
BFLOAT = 4


class DataType(ctypes.Structure):
@@ -65,7 +66,8 @@ class DataType(ctypes.Structure):
DataTypeCode.INT : 'int',
DataTypeCode.UINT : 'uint',
DataTypeCode.FLOAT : 'float',
DataTypeCode.HANDLE : 'handle'
DataTypeCode.HANDLE : 'handle',
DataTypeCode.BFLOAT : 'bfloat'
}
def __init__(self, type_str):
super(DataType, self).__init__()
@@ -96,6 +98,9 @@ def __init__(self, type_str):
self.type_code = DataTypeCode.HANDLE
bits = 64
head = ""
elif head.startswith("bfloat"):
self.type_code = DataTypeCode.BFLOAT
head = head[6:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
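
The Python-side parser mirrors the C++ change in data_type.h. A quick sketch of what the new branch accepts, using the class this diff touches:

from tvm._ffi.runtime_ctypes import DataType, DataTypeCode

dt = DataType("bfloat16")
assert dt.type_code == DataTypeCode.BFLOAT and dt.bits == 16
vec = DataType("bfloat16x4")   # vector forms parse too: 4 lanes of bf16
assert vec.lanes == 4
print(dt)                      # "bfloat16", via the new CODE2STR entry
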
1 change: 1 addition & 0 deletions python/tvm/driver/build_module.py
@@ -176,6 +176,7 @@ def lower(sch,
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
]
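
Because BF16Legalize now runs inside lower(), a schedule over bf16 tensors can be lowered directly. A small usage sketch, assuming a build with this commit and a target without native bf16 arithmetic:

import tvm
from tvm import te

n = 128
A = te.placeholder((n,), dtype="bfloat16", name="A")
B = te.placeholder((n,), dtype="bfloat16", name="B")
C = te.compute((n,), lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)
# The printed TIR computes in float32 and, after BF16TypeLowering,
# stores bf16 values as uint16.
print(tvm.lower(s, [A, B, C], simple_mode=True))
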
50 changes: 50 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -226,6 +226,56 @@ def RemoveNoOp():
"""
return _ffi_api.RemoveNoOp()

def BF16Legalize():
"""Legalize bf16 typed Ops.
Runs BF16Promote, BF16CastElimination and BF16TypeLowering
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16Legalize()
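
BF16Legalize is the composition of the three passes documented below, so a roughly equivalent pipeline can be spelled out by hand; mod here is a hypothetical IRModule containing bf16 arithmetic:

import tvm

legalize = tvm.transform.Sequential([
    tvm.tir.transform.BF16Promote(),
    tvm.tir.transform.BF16CastElimination(),
    tvm.tir.transform.BF16TypeLowering(),
])
# mod = legalize(mod)   # roughly what tvm.tir.transform.BF16Legalize()(mod) does
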

def BF16Promote():
"""Promote bf16 to fp32. Add a cast to fp32
before Ops, then add a cast back to bf16.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16Promote()
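
A numpy model of the promoted semantics may help. The helper bf16 below is illustrative only (truncation is used for narrowing; the rounding applied by the real lowering is defined by the pass implementation, not shown in this hunk):

import numpy as np

def bf16(x):
    # float32 -> bf16 -> float32 round trip: keep only the top 16 bits
    x = np.asarray(x, np.float32)
    return ((x.view(np.uint32) >> 16) << 16).view(np.float32)

X, Y = np.float32([1.125]), np.float32([2.5])
# source:    Z[i] = X[i] + Y[i]                          (all operands bf16)
# promoted:  Z[i] = bf16(float32(X[i]) + float32(Y[i]))
Z = bf16(bf16(X) + bf16(Y))
print(Z)  # [3.625]
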

def BF16CastElimination():
"""Eliminate verbose casting between fp32 and bf16
Checks if the AST has the pattern:
castto32(castto16(some_fp32_op(...)))
The verbose casting is generated by BF16Promote for multiple
bf16 Ops in a row. e.g.:
X[i] + Y[i] + T[i] =>
bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
After this pass:
bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16CastElimination()
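
Besides shrinking the AST, dropping the intermediate narrow-then-widen keeps more precision mid-chain, which a small numpy experiment makes visible (bf16 is the same illustrative truncating helper as above, redefined so the snippet stands alone):

import numpy as np

def bf16(x):
    x = np.asarray(x, np.float32)
    return ((x.view(np.uint32) >> 16) << 16).view(np.float32)

x, y, t = np.float32(1.0), np.float32(0.00390625), np.float32(-1.0)
verbose    = bf16(bf16(x + y) + t)   # what BF16Promote alone produces
eliminated = bf16(x + y + t)         # after BF16CastElimination
print(verbose, eliminated)           # 0.0 vs 0.00390625
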

def BF16TypeLowering():
"""Replace all bf16 type with uint16. Also lower the casting
between fp32 and bf16
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BF16TypeLowering()
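
The storage rule: a bf16 value is the upper half of the bit pattern of the corresponding float32, held in a uint16. A numpy sketch of the two casts this pass materializes (helper names are mine; truncation is shown for the narrowing direction, while the pass's actual rounding behavior is defined in its implementation):

import numpy as np

def bf16_to_fp32(u16):
    # widen: place the 16 stored bits in the top half of a float32
    return (u16.astype(np.uint32) << 16).view(np.float32)

def fp32_to_bf16(f32):
    # narrow: keep the top 16 bits (sign, 8-bit exponent, 7-bit mantissa)
    return (f32.view(np.uint32) >> 16).astype(np.uint16)

a = fp32_to_bf16(np.float32([1.5, 2.25, -3.0]))
print(a.dtype, bf16_to_fp32(a))   # uint16 [ 1.5  2.25 -3. ]
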

def RewriteUnsafeSelect():
"""Detect and rewrite unsafe select that contains memory access.
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -162,6 +162,7 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
// Phase 1
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LoopPartition());
(The remaining 3 changed files are not shown.)