diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index 960c12e746f3..26b34fc19c14 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -101,10 +101,14 @@ enum TypeFlag { kInt32 = 4, kInt8 = 5, kInt64 = 6, - kInt16 = 7, - kUint16 = 8, - kUint32 = 9, - kUint64 = 10, + // kBool = 7, + // 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in + // https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L314 + kInt16 = 8, + kUint16 = 9, + kUint32 = 10, + kUint64 = 11, + kBfloat16 = 12, }; enum IndicatorRuleFlag { @@ -126,7 +130,8 @@ enum IndicatorRuleFlag { .add_enum("int8", kInt8) \ .add_enum("int16", kInt16) \ .add_enum("int32", kInt32) \ - .add_enum("int64", kInt64) + .add_enum("int64", kInt64) \ + .add_enum("bfloat16", kBfloat16) struct CastParam : public dmlc::Parameter { int dtype; diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 780696d13899..a21c527962eb 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -41,6 +41,7 @@ static int GetDTypeSize(int type_flag) { case kInt8: return 1; case kFloat16: + case kBfloat16: case kInt16: case kUint16: return 2;