From d2eef92009aaa53721198d4bb14b3241118deb55 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 18 Nov 2024 09:55:54 +0100 Subject: [PATCH] Force context scaling and quantization in float32, add assertions to tests (#197) *Issue #, if available:* Fixes #193 *Description of changes:* Passing in contexts in lower precision than float32 may result in a drop of accuracy. This change ensures that the tokenizer (which does scaling and quantization) operates on a float32 batch. Tested across GPU/CPU and different context dtypes with ```python from itertools import product import pandas as pd import torch from chronos import ChronosPipeline import matplotlib.pyplot as plt # requires: pip install matplotlib import numpy as np df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv") for context_dtype, context_device, model_dtype, model_device in product( [torch.bfloat16, torch.float16, torch.float32], ["cpu"], # only cpu input supported at the moment [torch.bfloat16, torch.float16, torch.float32], ["cpu", "cuda"], ): pipeline = ChronosPipeline.from_pretrained( "amazon/chronos-t5-tiny", device_map=model_device, torch_dtype=model_dtype, ) forecast = pipeline.predict( context=torch.tensor(df["#Passengers"]).to(dtype=context_dtype, device=context_device), prediction_length=65, num_samples=20, limit_prediction_length=False, ) assert forecast.dtype == context_dtype, f"{forecast.dtype=} but {context_dtype=}" assert str(forecast.device) == context_device, f"{forecast.device=} but {context_device=}" forecast_index = range(len(df), len(df) + 65) low, median, high = np.quantile(forecast[0].to(device="cpu", dtype=torch.float32).numpy(), [0.1, 0.5, 0.9], axis=0) plt.figure(figsize=(8, 4)) plt.plot(df["#Passengers"], color="royalblue", label="historical data") plt.plot(forecast_index, median, color="tomato", label="median forecast") plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval") plt.legend() plt.grid() plt.show() ``` By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos.py | 8 ++++-- test/test_chronos.py | 55 ++++++++++++++++++++++++------------------ 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index c8ba344..000dc67 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -169,6 +169,7 @@ def __init__( def _input_transform( self, context: torch.Tensor, scale: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + context = context.to(dtype=torch.float32) attention_mask = ~torch.isnan(context) if scale is None: @@ -373,7 +374,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: size=(max_len - len(c),), fill_value=torch.nan, device=c.device ) padded.append(torch.concat((padding, c), dim=-1)) - return torch.stack(padded) + return torch.stack(padded).to(tensors[0]) @dataclass @@ -506,6 +507,9 @@ def predict( raise ValueError(msg) warnings.warn(msg) + input_dtype = context_tensor.dtype + input_device = context_tensor.device + predictions = [] remaining = prediction_length @@ -536,7 +540,7 @@ def predict( [context_tensor, prediction.median(dim=1).values], dim=-1 ) - return torch.cat(predictions, dim=-1) + return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device) @classmethod def from_pretrained(cls, *args, **kwargs): diff --git a/test/test_chronos.py b/test/test_chronos.py index a84c57d..e2b71dc 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -157,24 +157,26 @@ def test_tokenizer_random_data(use_eos_token: bool): assert samples.shape == (2, 10, 4) -def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None: - assert isinstance(samples, torch.Tensor) - assert samples.shape == shape +def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None: + assert isinstance(a, torch.Tensor) + assert a.shape == shape + assert a.dtype == dtype -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_predict(torch_dtype: str): +@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", - torch_dtype=torch_dtype, + torch_dtype=model_dtype, ) - context = 10 * torch.rand(size=(4, 16)) + 10 + context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 # input: tensor of shape (batch_size, context_length) samples = pipeline.predict(context, num_samples=12, prediction_length=3) - validate_tensor(samples, (4, 12, 3)) + validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) with pytest.raises(ValueError): samples = pipeline.predict(context, num_samples=7, prediction_length=65) @@ -182,12 +184,12 @@ def test_pipeline_predict(torch_dtype: str): samples = pipeline.predict( context, num_samples=7, prediction_length=65, limit_prediction_length=False ) - validate_tensor(samples, (4, 7, 65)) + validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) # input: batch_size-long list of tensors of shape (context_length,) samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) - validate_tensor(samples, (4, 12, 3)) + validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) with pytest.raises(ValueError): samples = pipeline.predict(list(context), num_samples=7, prediction_length=65) @@ -198,12 +200,12 @@ def test_pipeline_predict(torch_dtype: str): prediction_length=65, limit_prediction_length=False, ) - validate_tensor(samples, (4, 7, 65)) + validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) # input: tensor of shape (context_length,) samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) - validate_tensor(samples, (1, 12, 3)) + validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype) with pytest.raises(ValueError): samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65) @@ -214,36 +216,43 @@ def test_pipeline_predict(torch_dtype: str): prediction_length=65, limit_prediction_length=False, ) - validate_tensor(samples, (1, 7, 65)) + validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype) -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_embed(torch_dtype: str): +@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", - torch_dtype=torch_dtype, + torch_dtype=model_dtype, ) d_model = pipeline.model.model.config.d_model - context = 10 * torch.rand(size=(4, 16)) + 10 + context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) # input: tensor of shape (batch_size, context_length) embedding, scale = pipeline.embed(context) - validate_tensor(embedding, (4, expected_embed_length, d_model)) - validate_tensor(scale, (4,)) + validate_tensor( + embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(scale, shape=(4,), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) embedding, scale = pipeline.embed(list(context)) - validate_tensor(embedding, (4, expected_embed_length, d_model)) - validate_tensor(scale, (4,)) + validate_tensor( + embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(scale, shape=(4,), dtype=torch.float32) # input: tensor of shape (context_length,) embedding, scale = pipeline.embed(context[0, ...]) - validate_tensor(embedding, (1, expected_embed_length, d_model)) - validate_tensor(scale, (1,)) + validate_tensor( + embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(scale, shape=(1,), dtype=torch.float32) @pytest.mark.parametrize("n_tokens", [10, 1000, 10000])