Force context scaling and quantization in float32, add assertions to tests #197

Merged: 11 commits merged into amazon-science:main from fix-dtype on Nov 18, 2024

Conversation

@lostella (Contributor) commented Nov 6, 2024

Issue #, if available: Fixes #193 (Ensure that scaling is performed in FP32)

Description of changes: Passing in contexts in lower precision than float32 may result in a drop in accuracy. This change ensures that the tokenizer (which does scaling and quantization) operates on a float32 batch.
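
A minimal sketch of the idea behind the fix (the ensure_float32 helper is hypothetical, not the actual diff): upcast the context before the tokenizer scales and quantizes it, and let the caller restore the original dtype on the outputs.

import torch

def ensure_float32(context: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: mean scaling and quantization (bucketing values
    # into bins) lose accuracy when done in bfloat16/float16, so upcast to
    # float32 first; the predicted samples are cast back to the input dtype
    # afterwards.
    if context.dtype != torch.float32:
        context = context.to(torch.float32)
    return context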

Tested across GPU/CPU and different context dtypes with the following script:

from itertools import product

import pandas as pd
import torch
from chronos import ChronosPipeline

import matplotlib.pyplot as plt  # requires: pip install matplotlib
import numpy as np

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

for context_dtype, context_device, model_dtype, model_device in product(
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu"],  # only cpu input supported at the moment
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu", "cuda"],
):
    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-tiny",
        device_map=model_device,
        torch_dtype=model_dtype,
    )

    forecast = pipeline.predict(
        context=torch.tensor(df["#Passengers"]).to(dtype=context_dtype, device=context_device),
        prediction_length=65,
        num_samples=20,
        limit_prediction_length=False,
    )

    assert forecast.dtype == context_dtype, f"{forecast.dtype=} but {context_dtype=}"
    assert str(forecast.device) == context_device, f"{forecast.device=} but {context_device=}"

    forecast_index = range(len(df), len(df) + 65)
    low, median, high = np.quantile(forecast[0].to(device="cpu", dtype=torch.float32).numpy(), [0.1, 0.5, 0.9], axis=0)

    plt.figure(figsize=(8, 4))
    plt.plot(df["#Passengers"], color="royalblue", label="historical data")
    plt.plot(forecast_index, median, color="tomato", label="median forecast")
    plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    plt.legend()
    plt.grid()
    plt.show()

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@lostella requested a review from abdulfatir November 6, 2024 13:52
@lostella added the enhancement (New feature or request) label Nov 6, 2024
@lostella marked this pull request as draft November 6, 2024 16:27
@lostella marked this pull request as ready for review November 7, 2024 09:34
@lostella changed the title from "Force context scaling and quantization in float32" to "Force context scaling and quantization in float32, add assertions to tests" Nov 7, 2024
@lostella requested a review from abdulfatir November 8, 2024 08:53
@abdulfatir (Contributor) commented:
LGTM. One final request:

Can you please check if the eval script runs without failure?

python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-t5-small-zero-shot.csv \
    --chronos-model-id "amazon/chronos-t5-small" \
    --batch-size=32 \
    --device=cuda:0 \
    --num-samples 20 # maybe use fewer samples here for quicker eval

@abdulfatir previously approved these changes Nov 8, 2024
@lostella (Contributor, Author) commented Nov 9, 2024

@abdulfatir the script runs fine. MASE/WQL before the change on the left, after on the right:

| (dataset, model)                                             | MASE (before) | WQL (before) | MASE (after) | WQL (after) |
|:-------------------------------------------------------------|--------------:|-------------:|-------------:|------------:|
| ('ETTh', 'amazon/chronos-t5-small')                          |  0.789532 | 0.0779498 |  0.795824 | 0.0765788 |
| ('ETTm', 'amazon/chronos-t5-small')                          |  0.69098  | 0.0580811 |  0.682193 | 0.0579403 |
| ('dominick', 'amazon/chronos-t5-small')                      |  0.808827 | 0.336292  |  0.808234 | 0.336098  |
| ('ercot', 'amazon/chronos-t5-small')                         |  0.572246 | 0.0173959 |  0.563009 | 0.0145909 |
| ('exchange_rate', 'amazon/chronos-t5-small')                 |  2.11416  | 0.0159472 |  1.9094   | 0.0122543 |
| ('m4_quarterly', 'amazon/chronos-t5-small')                  |  1.23735  | 0.08375   |  1.23714  | 0.0837687 |
| ('m4_yearly', 'amazon/chronos-t5-small')                     |  3.74151  | 0.138683  |  3.75295  | 0.138851  |
| ('m5', 'amazon/chronos-t5-small')                            |  0.936658 | 0.589073  |  0.93704  | 0.589946  |
| ('monash_australian_electricity', 'amazon/chronos-t5-small') |  1.25666  | 0.0729795 |  1.21355  | 0.0723263 |
| ('monash_car_parts', 'amazon/chronos-t5-small')              |  0.889917 | 1.03773   |  0.881996 | 1.02642   |
| ('monash_cif_2016', 'amazon/chronos-t5-small')               |  0.982842 | 0.0144385 |  0.978809 | 0.0139298 |
| ('monash_covid_deaths', 'amazon/chronos-t5-small')           | 42.4525   | 0.0659236 | 42.3755   | 0.0641955 |
| ('monash_fred_md', 'amazon/chronos-t5-small')                |  0.476279 | 0.0165834 |  0.491483 | 0.0161477 |
| ('monash_hospital', 'amazon/chronos-t5-small')               |  0.71274  | 0.057644  |  0.710601 | 0.0572101 |
| ('monash_m1_monthly', 'amazon/chronos-t5-small')             |  1.15745  | 0.136211  |  1.17341  | 0.139096  |
| ('monash_m1_quarterly', 'amazon/chronos-t5-small')           |  1.8176   | 0.113509  |  1.77551  | 0.117829  |
| ('monash_m1_yearly', 'amazon/chronos-t5-small')              |  4.75944  | 0.177472  |  4.84628  | 0.171031  |
| ('monash_m3_monthly', 'amazon/chronos-t5-small')             |  0.887199 | 0.0996982 |  0.889429 | 0.0997251 |
| ('monash_m3_quarterly', 'amazon/chronos-t5-small')           |  1.27145  | 0.0799164 |  1.27048  | 0.0798927 |
| ('monash_m3_yearly', 'amazon/chronos-t5-small')              |  3.38602  | 0.158457  |  3.36095  | 0.15862   |
| ('monash_nn5_weekly', 'amazon/chronos-t5-small')             |  0.952411 | 0.090471  |  0.965875 | 0.0925317 |
| ('monash_tourism_monthly', 'amazon/chronos-t5-small')        |  1.91375  | 0.112959  |  1.9196   | 0.111466  |
| ('monash_tourism_quarterly', 'amazon/chronos-t5-small')      |  1.75058  | 0.0690398 |  1.75528  | 0.064458  |
| ('monash_tourism_yearly', 'amazon/chronos-t5-small')         |  3.93176  | 0.202732  |  3.9776   | 0.200138  |
| ('monash_traffic', 'amazon/chronos-t5-small')                |  0.823954 | 0.259314  |  0.824532 | 0.25917   |
| ('monash_weather', 'amazon/chronos-t5-small')                |  0.852168 | 0.147845  |  0.853171 | 0.147852  |
| ('nn5', 'amazon/chronos-t5-small')                           |  0.619628 | 0.168366  |  0.618881 | 0.168871  |

@lostella merged commit d2eef92 into amazon-science:main Nov 18, 2024
2 checks passed
@lostella deleted the fix-dtype branch November 18, 2024 08:55
@XiaoJia849 commented:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import pandas as pd  # requires: pip install pandas
import torch
from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny",  # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
    device_map="cuda",  # use "cpu" for CPU inference
    torch_dtype=torch.bfloat16,
    # torch_dtype=torch.float16,
)

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)
quantiles, mean = pipeline.predict_quantiles(
    context=torch.tensor(df["#Passengers"]),
    prediction_length=12,
    quantile_levels=[0.1, 0.5, 0.9],
)

I am curious why the torch_dtype is torch.bfloat16. I ran this code on an NVIDIA GeForce GTX TITAN X, which raises this error:

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasGemmStridedBatchedEx(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, (int)num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)

When I change torch.bfloat16 to torch.float16, the error disappears. I wonder why the author set torch_dtype=torch.bfloat16 initially?

@lostella (Contributor, Author) commented:

@XiaoJia849 bfloat16 is a 16-bit floating point format with the same exponent range as float32 (it uses the same number of bits for the exponent), but with less precision. In our experience, it works just as well as float32 for inference, with half the memory requirement. Not all GPUs support it, though.
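
For concreteness, here is a small sketch comparing the two formats and checking hardware support (torch.cuda.is_bf16_supported is available in recent PyTorch versions):

import torch

# float16 uses 5 exponent bits, so its max representable value is 65504;
# bfloat16 uses 8 exponent bits like float32, so its range reaches ~3.4e38,
# at the cost of a coarser machine epsilon (~7.8e-3 vs ~9.8e-4).
print(torch.finfo(torch.float16))
print(torch.finfo(torch.bfloat16))

# Check whether the current GPU supports bfloat16 before requesting it:
if torch.cuda.is_available():
    print(torch.cuda.is_bf16_supported())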
