Skip to content

Commit

Permalink
Force context scaling and quantization in float32, add assertions to …
Browse files Browse the repository at this point in the history
…tests (#197)

*Issue #, if available:* Fixes #193

*Description of changes:* Passing in contexts in lower precision than
float32 may result in a drop of accuracy. This change ensures that the
tokenizer (which does scaling and quantization) operates on a float32
batch.

Tested across GPU/CPU and different context dtypes with

```python
from itertools import product

import pandas as pd
import torch
from chronos import ChronosPipeline

import matplotlib.pyplot as plt  # requires: pip install matplotlib
import numpy as np

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

for context_dtype, context_device, model_dtype, model_device in product(
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu"],  # only cpu input supported at the moment
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu", "cuda"],
):
    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-tiny",
        device_map=model_device,
        torch_dtype=model_dtype,
    )

    forecast = pipeline.predict(
        context=torch.tensor(df["#Passengers"]).to(dtype=context_dtype, device=context_device),
        prediction_length=65,
        num_samples=20,
        limit_prediction_length=False,
    )

    assert forecast.dtype == context_dtype, f"{forecast.dtype=} but {context_dtype=}"
    assert str(forecast.device) == context_device, f"{forecast.device=} but {context_device=}"

    forecast_index = range(len(df), len(df) + 65)
    low, median, high = np.quantile(forecast[0].to(device="cpu", dtype=torch.float32).numpy(), [0.1, 0.5, 0.9], axis=0)

    plt.figure(figsize=(8, 4))
    plt.plot(df["#Passengers"], color="royalblue", label="historical data")
    plt.plot(forecast_index, median, color="tomato", label="median forecast")
    plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    plt.legend()
    plt.grid()
    plt.show()
```


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
lostella authored Nov 18, 2024
1 parent ac6ee36 commit d2eef92
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 25 deletions.
8 changes: 6 additions & 2 deletions src/chronos/chronos.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
def _input_transform(
self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
context = context.to(dtype=torch.float32)
attention_mask = ~torch.isnan(context)

if scale is None:
Expand Down Expand Up @@ -373,7 +374,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
)
padded.append(torch.concat((padding, c), dim=-1))
return torch.stack(padded)
return torch.stack(padded).to(tensors[0])


@dataclass
Expand Down Expand Up @@ -506,6 +507,9 @@ def predict(
raise ValueError(msg)
warnings.warn(msg)

input_dtype = context_tensor.dtype
input_device = context_tensor.device

predictions = []
remaining = prediction_length

Expand Down Expand Up @@ -536,7 +540,7 @@ def predict(
[context_tensor, prediction.median(dim=1).values], dim=-1
)

return torch.cat(predictions, dim=-1)
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)

@classmethod
def from_pretrained(cls, *args, **kwargs):
Expand Down
55 changes: 32 additions & 23 deletions test/test_chronos.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,37 +157,39 @@ def test_tokenizer_random_data(use_eos_token: bool):
assert samples.shape == (2, 10, 4)


def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
assert isinstance(samples, torch.Tensor)
assert samples.shape == shape
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
assert isinstance(a, torch.Tensor)
assert a.shape == shape
assert a.dtype == dtype


@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_predict(torch_dtype: str):
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
torch_dtype=torch_dtype,
torch_dtype=model_dtype,
)
context = 10 * torch.rand(size=(4, 16)) + 10
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10

# input: tensor of shape (batch_size, context_length)

samples = pipeline.predict(context, num_samples=12, prediction_length=3)
validate_tensor(samples, (4, 12, 3))
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)

with pytest.raises(ValueError):
samples = pipeline.predict(context, num_samples=7, prediction_length=65)

samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
)
validate_tensor(samples, (4, 7, 65))
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)

# input: batch_size-long list of tensors of shape (context_length,)

samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
validate_tensor(samples, (4, 12, 3))
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)

with pytest.raises(ValueError):
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
Expand All @@ -198,12 +200,12 @@ def test_pipeline_predict(torch_dtype: str):
prediction_length=65,
limit_prediction_length=False,
)
validate_tensor(samples, (4, 7, 65))
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)

# input: tensor of shape (context_length,)

samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
validate_tensor(samples, (1, 12, 3))
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)

with pytest.raises(ValueError):
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
Expand All @@ -214,36 +216,43 @@ def test_pipeline_predict(torch_dtype: str):
prediction_length=65,
limit_prediction_length=False,
)
validate_tensor(samples, (1, 7, 65))
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)


@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_embed(torch_dtype: str):
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
torch_dtype=torch_dtype,
torch_dtype=model_dtype,
)
d_model = pipeline.model.model.config.d_model
context = 10 * torch.rand(size=(4, 16)) + 10
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)

# input: tensor of shape (batch_size, context_length)

embedding, scale = pipeline.embed(context)
validate_tensor(embedding, (4, expected_embed_length, d_model))
validate_tensor(scale, (4,))
validate_tensor(
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(scale, shape=(4,), dtype=torch.float32)

# input: batch_size-long list of tensors of shape (context_length,)

embedding, scale = pipeline.embed(list(context))
validate_tensor(embedding, (4, expected_embed_length, d_model))
validate_tensor(scale, (4,))
validate_tensor(
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(scale, shape=(4,), dtype=torch.float32)

# input: tensor of shape (context_length,)
embedding, scale = pipeline.embed(context[0, ...])
validate_tensor(embedding, (1, expected_embed_length, d_model))
validate_tensor(scale, (1,))
validate_tensor(
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(scale, shape=(1,), dtype=torch.float32)


@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
Expand Down

0 comments on commit d2eef92

Please sign in to comment.