Force context scaling and quantization in float32, add assertions to tests (#197)

*Issue #, if available:* Fixes #193

*Description of changes:* Passing in contexts in lower precision than float32 may result in a drop in accuracy. This change ensures that the tokenizer (which does scaling and quantization) operates on a float32 batch.

Tested across GPU/CPU and different context dtypes with:

```python
from itertools import product

import pandas as pd
import torch
from chronos import ChronosPipeline
import matplotlib.pyplot as plt  # requires: pip install matplotlib
import numpy as np

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

for context_dtype, context_device, model_dtype, model_device in product(
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu"],  # only cpu input supported at the moment
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu", "cuda"],
):
    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-tiny",
        device_map=model_device,
        torch_dtype=model_dtype,
    )

    forecast = pipeline.predict(
        context=torch.tensor(df["#Passengers"]).to(dtype=context_dtype, device=context_device),
        prediction_length=65,
        num_samples=20,
        limit_prediction_length=False,
    )

    assert forecast.dtype == context_dtype, f"{forecast.dtype=} but {context_dtype=}"
    assert str(forecast.device) == context_device, f"{forecast.device=} but {context_device=}"

    forecast_index = range(len(df), len(df) + 65)
    low, median, high = np.quantile(
        forecast[0].to(device="cpu", dtype=torch.float32).numpy(), [0.1, 0.5, 0.9], axis=0
    )

    plt.figure(figsize=(8, 4))
    plt.plot(df["#Passengers"], color="royalblue", label="historical data")
    plt.plot(forecast_index, median, color="tomato", label="median forecast")
    plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    plt.legend()
    plt.grid()
    plt.show()
```

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
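For context on what the change addresses, the sketch below illustrates the idea in isolation. `scale_and_quantize` is a hypothetical stand-in for the tokenizer's mean-scaling and uniform-binning step (it is not the actual Chronos implementation, and the bin count and limits are illustrative); the point is that the arithmetic is forced into float32 regardless of the dtype of the context passed by the caller.

```python
import torch

def scale_and_quantize(context: torch.Tensor, n_bins: int = 1024) -> torch.Tensor:
    """Hypothetical stand-in for the tokenizer's scaling + quantization step."""
    # Do all arithmetic in float32, regardless of the caller's dtype, so that
    # mean-scaling and binning are not degraded by bfloat16/float16 precision.
    context = context.to(torch.float32)
    scale = context.abs().mean().clamp(min=1e-10)
    scaled = context / scale
    # Uniform binning into integer token ids (illustrative bin limits).
    boundaries = torch.linspace(-15.0, 15.0, n_bins - 1)
    return torch.bucketize(scaled, boundaries)

# The internal cast makes the step safe to call with low-precision contexts.
context = torch.arange(1, 145, dtype=torch.bfloat16)
token_ids = scale_and_quantize(context)
print(token_ids.dtype, token_ids.shape)  # torch.int64, torch.Size([144])
```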