Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataType] Add bfloat16 #5601

Merged
merged 47 commits into from
Jun 19, 2020
Merged

[DataType] Add bfloat16 #5601

merged 47 commits into from
Jun 19, 2020

Conversation

Menooker
Copy link
Contributor

@Menooker Menooker commented May 15, 2020

We add bfloat16 as a new type named "bf16" in the frontend. Completed LLVM backend for generating bf16.

  • Use int16 as the storage type in LLVM
  • Add legalization to enable computations on bf16
  • Add runtime frontend support (e.g. allow converting numpy's uint16 array to bf16 NDArray)

Details on legalization

Since most of the HW has no native support for computation on bf16, we added a pass BF16Legalization to use fp32 computing bf16 data. It adds cast_to_fp32() before each Op involing bf16 operands, and use Ops of fp32 to compute. Finally, it adds a 'cast_to_bf16()' after each Op that is altered. e.g.

add(a,b) => cast16(add(cast32(a), cast32(b)))

We call this phase as "BF16Promotion". It is a sub-pass of BF16Legalization pass.

We note that this will add redundant casting. e.g.

add(a, neg(b)) => cast16(add(cast32(a), cast32(cast16(neg(cast32(b)))))

The pattern cast32(cast16(some_fp32_value)) can be simplified to some_fp32_value.

Thus, we add an optimization pass after "BF16Promotion" in BF16Legalization pass, which eliminates redundant casts.

After BF16Legalization pass, there will be no bf16 related computation in the AST, except casting between fp32 and bf16, bf16 value comparasion and assignment.

Casting between fp32 and bf16

We follow PyTorch's bf16 casting implementation.

@Menooker
Copy link
Contributor Author

@zhiics @liangfu Please help review this PR. Thanks!

@tqchen
Copy link
Member

tqchen commented May 15, 2020

cc @gussmith23 might be related to BYOD

Copy link
Member

@liangfu liangfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Menooker for the great work! The proposed changes mostly looks good. I left a few comments.

@@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
default:
LOG(FATAL) << "do not support " << dtype;
}
} else if (dtype.is_bfloat()) {
CHECK_EQ(dtype.bits(), 16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since bfloat is assumed to be 16bit, can we keep the terminology more consistent? Since the data type is termed as bf, bf16, bfloat16, bfloat in the proposed change. Or are we going to support more data types like bfloat18 and bfloat20 in the future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the inclarity. I think in bfloat[X], only X=16 makes sense. But TVM's type system allows specifying the bits of a type. So here is the checking to make sure it is bf16.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good question. Will we treat TensorFloat-32 as bfloat20? If so, then bits is useful to distinguish those.

if __name__ == "__main__":
test_promote()
test_eliminate()
test_legalize()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please leave a new line at EOF, even this is test script :)

def np_bf162np_float(arr):
''' Convert a numpy array of bf16 (uint16) to a numpy array
of float'''
u32 = np.left_shift(arr.astype('uint32'), 16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we going to produce a potential endianness problem here?

Copy link
Contributor Author

@Menooker Menooker May 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my understanding, fp32=>bf16 casting preserves the higher-ordered bits (bits 31-16). We don't need to know whether the higher-ordered bits are stored in a larger address or a smaller address (which is the endianness), we just need to get the bits by shifting, which is well-defined - just using shifting is enough.

Reference: wiki for fp32 bit order

PyTorch's bf16 casting

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% sure about this. I have tested the code on x86, not (yet) on other arch.

Copy link
Member

@liangfu liangfu May 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reused the following code snippet, which preserves endianness checks?

https://github.com/apache/incubator-tvm/blob/6cbda80227fc18a859c4b01f57f75abbd7a16181/3rdparty/bfloat16/bfloat16.cc#L27-L35

And it has wrapper functions below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If my understanding is correct, we don't need to care about endianness. BF16 conversions only involves getting higher-ordered bits. And the operation to get higher-ordered bits in C++/Numpy is well-defined.

@@ -906,7 +954,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
if (t.is_int()) { \
return builder_->CreateICmpS##Op(a, b); \
} else if (t.is_uint()) { \
} else if (t.is_uint() || t.is_bfloat()) { \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't comparing bfloat16 this way risky?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP32/FP64 comparasion are also bit-wise in my understanding.

? static_cast<llvm::Type*>(builder_->getInt32Ty())
: llvm::VectorType::get(builder_->getInt32Ty(), from.lanes());
auto v = builder_->CreateZExt(value, extended_type);
v = builder_->CreateShl(v, 16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential endianness problem here?

@@ -114,6 +114,7 @@ typedef enum {
kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kTVMExtReserveEnd = 64U,
kTVMBFloat = 65U,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not want BFloat to be passed as PackedFunc argument, most packedfunc argument should always be passed as double

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose TVM should support kernel generation, e.g. generating a fused "conv+bn+relu", rather than generating end-to-end model, which is the usual case. In this case, we might select some intermediate layers of the model and let TVM generate the selected layers. The layers may require bf16 as the dtype, as they are in the middle of the model.

What I want to say is that we sometimes need bf16 as the input dtype. In our usecase in Intel, we need to generate a bf16 kernel (e.g. conv+bn+relu).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such dtype is covered by allocating a DLTensor with type_code equals kBFloat, and does not need patch to the code here(needed for parameter argument passing PackedFunc).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The particular code is used when we directly pass a constant into PackedFunc, e.g. f(1.0, some_float_value). in these cases double can be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove this type from TVM runtime, we cannot pass a bf16 array to TVM via Python and users can only pass bf16 buffers via C runtime (or in some awkward way to construct a bf16 DLTensor via Python). Currently, with kTVMBFloat defined, we can:

        A = te.placeholder((32, ), dtype='bfloat16')
        B = te.placeholder((32, ), dtype='bfloat16')
        d = te.compute((32, ), lambda x: A[x] + B[x])
        sch = te.create_schedule(d.op)
        module = tvm.build(sch, [A, B, d])
        npa = np.random.rand(32).astype('float32')
        npb = np.random.rand(32).astype('float32')
        a_ = np_float2tvm_bf16(npa)
        b_ = np_float2tvm_bf16(npb)
        c_ = tvm.nd.empty((32,), 'bfloat16')
        module(a_, b_, c_)

Which is useful for testing and prototyping.

Copy link
Member

@tqchen tqchen May 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you will kTVMBFloat to support this feature. The DataType::kDLBFloat flag in the runtime::DataType should be sufficient for NDArray contents(because the runtime::DataType's type code in the NDArray contents diverges from the TVM type code above the OpaqueHandle).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I understand. will change that

@@ -81,6 +82,10 @@ class DataType {
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat type. */
bool is_bfloat() const { return code() == DataType::kBFloat; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given that only bfloat16 is defined, is_bf16 is a good enough function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, changed

@@ -297,6 +302,8 @@ inline const char* TypeCode2Str(int type_code) {
return "Object";
case kTVMObjectRValueRefArg:
return "ObjectRValueRefArg";
case kTVMBFloat:
return "bf";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, changed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, changed

// cast operatpr
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
llvm::Type* target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_float() && from.is_bfloat()) {
CHECK_EQ(from.bits(), 16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If LLVM does not support bfloat, then perhaps we should do the legalization as a TIR=>TIR pass as opposed to do it in LLVM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are actually doing TIR=>TIR legalization pass in TVM. See src/tir/transforms/bf16_legalize.cc

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we should directly change the type to be i16 during legalization and remove special handling code for bfloat16

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we cannot tell whether it is a float32 => i16 or float32 => bfloat16 casting

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There're 2 kinds of legalization:

  1. TIR->TIR. TIR has full ability to describe any bfloat16 operation after this PR. This legalization is introduced just because of hardware limitation that current hardware only provide few bfloat16 operations. One day when hardware has full instructions support with bfloat16, ideally this legalization can be skipped. So this legalization is a target dependent pass.
  2. TIR->LLVM IR. I guess this is the legalization that @tqchen mentions. Because LLVM IR doesn't natively support bfloat16 , i16 will be used to replace bfloat16. In this PR, I guess this is done within codegen_llvm, not by a particular pass.

Copy link
Member

@tqchen tqchen May 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we should legalize the cast as well in the TIR to introduce the actual impl of the cast funtions in TIR, please also refer to https://tvm.apache.org/2020/05/20/bring-your-own-datatypes for releated implemenetation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just 2 small questions.

Did you mean totally eliminating bf16 dtype in legalization pass? This will bring much more complexity in the BF16Legalize pass, because we need to check every TIR node to replace bf16 with int16. In contrast, current impl only changes computation TIR nodes. And in the codegen, the bf16 generation is quite simple, just adding another ‘else if’ in casting node and tvm dtype to llvm type converter

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and I think the way processing “custom data type” that you mentioned does not fit this pr well. Actually I have already notice this feature before I wrote this bf16 feature. But it needs function calls to do lowering, which is not friendly to the codegen backend to do auto vectorization and so on. Of course you can say we can implement this cast function as an intrinsic. Yes, but more complexity is brought.

I think letting bf16 dtype live until codegen is a good idea, it makes legalization, impl of casting easier

@tqchen tqchen added the status: need RFC need RFC discussion label May 21, 2020
@tqchen
Copy link
Member

tqchen commented May 21, 2020

Given that this is a new feature that will affect quite some people, please open a new RFC thread in the discuss forum to describe the motivation and the high level design. Thank you!

@Menooker
Copy link
Contributor Author

Menooker commented Jun 5, 2020

@tqchen Thanks for the clarification. I have changed kTVMNullPtr back to 4.

@tqchen
Copy link
Member

tqchen commented Jun 9, 2020

@vinx13 @ZihengJiang @liangfu it would be great if you cam take another look. Thanks @Menooker for keep improving the PR


// implementation from
// https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h
inline uint16_t round_to_nearest_even(float src) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Google C Style, we cannot directly copy code from another codebase into the mainline, we would need to either put it in 3rdparty, or implement it independently

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

@tqchen
Copy link
Member

tqchen commented Jun 12, 2020

@junrushao1994 can you also take a quick look at this PR. thank you!

@tqchen tqchen self-assigned this Jun 12, 2020
Comment on lines 137 to 144
def orig1(a,b):
return lambda i: a[i]+b[i]+a[99-i]+b[99-i]
def after1(a,b):
return lambda i: to16(to32(a[i])+to32(b[i])+to32(a[99-i])+to32(b[99-i]))
def orig2(a,b):
return lambda i: a[i]*b[i]+a[99-i]*b[99-i]+a[i]
def after2(a,b):
return lambda i: to16(to32(a[i])*to32(b[i])+to32(a[99-i])*to32(b[99-i])+to32(a[i]))
Copy link
Member

@junrushao junrushao Jun 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am not so sure why the coding style here can pass pylint...Mind sending a simple fix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for that. Now I have formatted and manually run pylint on this test file. BTW, the test python files are never checked in TVM CI's pylint :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooops I see. That makes sense then :-)

@tqchen
Copy link
Member

tqchen commented Jun 16, 2020

@@ -72,6 +73,9 @@ class DataType {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
if (code == kBFloat) {
CHECK_EQ(bits, 16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is understandable that right now we only support bf16, but my concern is that "should we put the check here"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your concern. Any suggestions for the location where we put this check? Thanks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen This is just a nitpick. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us leave it as it is for now, we can come back to it later

@@ -372,7 +372,7 @@ inline DLDataType String2DLDataType(std::string s) {
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "bfloat") {
t.code = kTVMBFloat;
t.code = kDLBfloat;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with tq

@@ -27,7 +27,7 @@ cdef enum TVMTypeCode:
kUInt = 1
kFloat = 2
kTVMOpaqueHandle = 3
kTVMNullptr = 4
kBFloat = 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we remove this?

@@ -96,6 +98,9 @@ def __init__(self, type_str):
self.type_code = DataTypeCode.HANDLE
bits = 64
head = ""
elif head.startswith("bfloat"):
self.type_code = 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if it is good to hard code here

Copy link
Contributor Author

@Menooker Menooker Jun 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if it is good to hard code here

Change to DataTypeCode. TVM refactors a lot (which is good). And when this PR was raised, all the type code here used hard codes.

The other two issues you raised were also changed as required.

@tqchen
Copy link
Member

tqchen commented Jun 18, 2020

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM :-)

@@ -72,6 +73,9 @@ class DataType {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
if (code == kBFloat) {
CHECK_EQ(bits, 16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us leave it as it is for now, we can come back to it later

@tqchen tqchen merged commit 9b7c078 into apache:master Jun 19, 2020
@tqchen
Copy link
Member

tqchen commented Jun 19, 2020

Thanks @Menooker for being patient and keep improving the PR to maintain a high quality standard! Thanks @ZhennanQin @junrushao1994 @liangfu for helpful reviews!

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 30, 2020
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Jul 2, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants