
Better Bfloat16 support #777

Merged
merged 3 commits into Xilinx:dev on Dec 22, 2023

Conversation

Giuseppe5 (Collaborator)

No description provided.

@Giuseppe5 Giuseppe5 changed the title Bfloat16 support in ptq evaluate Better Bfloat16 support Dec 11, 2023
@nickfraser (Collaborator)

I think this is good, but there are a few things we may want to revisit down the line:

  1. The function name float32_kthvalue:
    • There is a slight mismatch between the function name and what actually runs: the name implies that kthvalue will always run in float32, but I believe the upcast only happens for specific datatype/device combinations; a float64 input will still run the algorithm in float64.
      • It may be better to simply call the function kthvalue, or to change the functionality to match the name (a possible shape for this is sketched after the examples below).
  2. The upconversion of datatypes in QuantTensorBase._pre_round_int_value() seems to mean that if x.value is a bfloat16 tensor then the result of x._pre_round_int_value() will be float32, whereas otherwise the returned int_value has the same dtype as the input (e.g., float16, float32 or float64).
    • I'm not saying this is wrong, but it may have undesirable side effects down the line if downstream values are implicitly upcast to higher precision, especially if we want to take advantage of an accelerated bfloat16 backend. Consider the examples below*. I'm just cautious that this may cause downstream operations to differ from the ones we want...

*

>>> import torch
>>> x = torch.rand((1,),dtype=torch.bfloat16,device="cuda:0")
>>> y = torch.rand((1,),dtype=torch.bfloat16,device="cuda:0")
>>> r = x + y # Output type == input type
>>> r.dtype
torch.bfloat16
>>> import torch
>>> x = torch.rand((1,),dtype=torch.bfloat16,device="cuda:0")
>>> y = torch.rand((1,),dtype=torch.float32,device="cuda:0")
>>> r = x + y # Implicit upcast of x
>>> r.dtype
torch.float32
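
For reference, a minimal sketch (not the current Brevitas implementation; the set of upcast dtypes and the wrapper name are assumptions) of a kthvalue wrapper that upcasts only when needed and restores the caller's dtype:

import torch

# Dtypes assumed, for illustration, to need an upcast before calling torch.kthvalue.
_UPCAST_DTYPES = (torch.bfloat16, torch.float16)

def kthvalue(x: torch.Tensor, k: int, dim: int = -1, keepdim: bool = False):
    # Upcast only when required, run the algorithm, then cast the values back.
    orig_dtype = x.dtype
    if orig_dtype in _UPCAST_DTYPES:
        x = x.to(torch.float32)
    values, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)
    return values.to(orig_dtype), indices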

@Giuseppe5 (Collaborator, Author)

Regarding 1, I will rename the function to match its functionality.

Regarding 2, in the current implementation the output of QuantTensor.int() is actually always float32 (even if the original QuantTensor was in float16, for example).
I believe the solution could be to cast the output after rounding back to the original dtype (be it bfloat16, float16, or float64), which is essentially the behaviour you assumed was already happening (but currently isn't).
This way we would keep the dtype consistent throughout the computation. Also, since the output of _pre_round_int_value is only ever used internally, we just need to make sure the post-rounding value has the correct dtype: once the value is an integer, we no longer have the issue of floating-point values being represented inexactly in (b)float16.
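
A minimal sketch of that post-rounding cast (the function and argument names here are illustrative, not the actual Brevitas API):

import torch

def int_value(value: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # Do the pre-round computation in float32 for numerical safety.
    orig_dtype = value.dtype
    pre_round = value.to(torch.float32) / scale.to(torch.float32) + zero_point.to(torch.float32)
    # After rounding, the result is an exact integer (within range), so casting back
    # to bfloat16/float16 does not reintroduce the inexact-representation issue.
    return torch.round(pre_round).to(orig_dtype)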

@Giuseppe5 Giuseppe5 added the next release PRs which should be merged for the next release label Dec 20, 2023
@Giuseppe5 Giuseppe5 requested review from nickfraser and volcacius and removed request for nickfraser December 21, 2023 10:18
@Giuseppe5 Giuseppe5 requested review from nickfraser and volcacius and removed request for volcacius and nickfraser December 21, 2023 13:22
@Giuseppe5 Giuseppe5 merged commit ade1036 into Xilinx:dev Dec 22, 2023
22 checks passed
@Giuseppe5 Giuseppe5 deleted the bfloat16_support branch December 22, 2023 09:28