Force context scaling and quantization in float32, add assertions to tests #197
Conversation
LGTM. One final request: can you please check that the eval script runs without failure?

```sh
python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-t5-small-zero-shot.csv \
    --chronos-model-id "amazon/chronos-t5-small" \
    --batch-size=32 \
    --device=cuda:0 \
    --num-samples 20  # maybe use fewer samples here for quicker eval
```
@abdulfatir The script runs fine. Left MASE/WQL is before, right MASE/WQL is after:
I am curious why torch_dtype is torch.bfloat16. I ran this code on an NVIDIA GeForce GTX TITAN X, which raised an error:
When I change torch.bfloat16 to torch.float16, the error disappears. I wonder why the author set torch_dtype=torch.bfloat16 initially?
@XiaoJia849 bfloat16 is a 16-bit floating-point format with the same exponent range as float32 (it uses the same number of bits for the exponent) but less precision. In our experience, it works just as well as float32 for inference, with half the memory requirement. Not all GPUs support it, though.
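Pre-Ampere NVIDIA cards such as the GTX TITAN X are among the GPUs that lack bfloat16 support, which is why switching to float16 makes the error above go away. Below is a minimal sketch of a runtime fallback, assuming the `ChronosPipeline` loading API from the repository README; the fallback logic itself is illustrative:

```python
import torch
from chronos import ChronosPipeline

if torch.cuda.is_available():
    device_map = "cuda:0"
    # Use bfloat16 only where the hardware supports it; older GPUs
    # (e.g., GTX TITAN X) fall back to float16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device_map = "cpu"
    dtype = torch.float32

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map=device_map,
    torch_dtype=dtype,
)
```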
Issue #, if available: Fixes #193
Description of changes: Passing in contexts in lower precision than float32 may result in a drop in accuracy. This change ensures that the tokenizer (which does scaling and quantization) operates on a float32 batch.
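The idea, in a minimal sketch (not the repository's actual tokenizer code): cast the incoming batch to float32 before mean-scaling and bucketizing, so that half-precision contexts do not degrade the resulting token ids.

```python
import torch

def scale_and_quantize(context: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for the tokenizer's scaling + quantization step."""
    # Force float32 so float16/bfloat16 inputs are scaled and bucketized
    # at full precision, whatever dtype the caller passed in.
    context = context.to(torch.float32)
    # Mean-scale each series by the mean absolute value of its context.
    scale = context.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8)
    # Quantize the scaled values into token ids.
    return torch.bucketize(context / scale, boundaries)
```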
Tested across GPU/CPU and different context dtypes.
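A hedged sketch of what such a dtype check might look like, reusing the `scale_and_quantize` sketch above; the names are illustrative and not the repository's actual tests:

```python
import pytest
import torch

@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_context_dtype_does_not_change_tokens(dtype):
    torch.manual_seed(0)
    boundaries = torch.linspace(-15.0, 15.0, steps=1000)
    context = torch.randn(4, 64)
    ids_ref = scale_and_quantize(context, boundaries)
    ids = scale_and_quantize(context.to(dtype), boundaries)
    # The tokenizer casts to float32 internally, so lower-precision inputs
    # should land in (almost) the same buckets; allow off-by-one shifts
    # from rounding the raw half-precision input at bucket edges.
    assert ids.dtype == torch.int64
    assert (ids - ids_ref).abs().max() <= 1
```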
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.