Add pipeline.embed support for Chronos-Bolt #247

Merged · 4 commits · Dec 22, 2024
11 changes: 3 additions & 8 deletions .github/workflows/eval-model.yml
@@ -12,7 +12,7 @@ on:
     - labeled # When a label is added to the PR

 jobs:
-  evaluate-and-post:
+  evaluate-and-print:
     if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
     runs-on: ubuntu-latest
     env:
@@ -33,10 +33,5 @@ jobs:
       - name: Run Eval Script
         run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32

-      - name: Upload CSV
-        uses: actions/upload-artifact@v4
-        with:
-          name: eval-metrics
-          path: ${{ env.RESULTS_CSV }}
-          retention-days: 1
-          overwrite: true
+      - name: Print CSV
+        run: cat $RESULTS_CSV
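Reviewer note: the upload-artifact step is replaced by a plain cat of the results file. A minimal sketch for reproducing the CI evaluation step locally follows; running from the repository root and the results.csv output path are assumptions (the workflow itself uses the $RESULTS_CSV environment variable).

# Sketch: reproduce the CI eval step locally (not part of this PR's diff).
# Assumes the repository root as working directory; "results.csv" stands in
# for the workflow's $RESULTS_CSV environment variable.
import subprocess

subprocess.run(
    [
        "python",
        "scripts/evaluation/evaluate.py",
        "ci/evaluate/backtest_config.yaml",
        "results.csv",
        "--chronos-model-id=amazon/chronos-bolt-small",
        "--device=cpu",
        "--torch-dtype=float32",
    ],
    check=True,
)
print(open("results.csv").read())  # mirrors the new "Print CSV" step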
68 changes: 59 additions & 9 deletions src/chronos/chronos_bolt.py
@@ -25,7 +25,6 @@

 from .base import BaseChronosPipeline, ForecastType

-
 logger = logging.getLogger(__file__)


@@ -240,13 +239,11 @@ def _init_weights(self, module):
         ):
             module.output_layer.bias.data.zero_()

-    def forward(
-        self,
-        context: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        target: Optional[torch.Tensor] = None,
-        target_mask: Optional[torch.Tensor] = None,
-    ) -> ChronosBoltOutput:
+    def encode(
+        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> Tuple[
+        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
+    ]:
         mask = (
             mask.to(context.dtype)
             if mask is not None
@@ -301,8 +298,21 @@ def forward(
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
         )
-        hidden_states = encoder_outputs[0]

+        return encoder_outputs[0], loc_scale, input_embeds, attention_mask
+
+    def forward(
+        self,
+        context: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        target: Optional[torch.Tensor] = None,
+        target_mask: Optional[torch.Tensor] = None,
+    ) -> ChronosBoltOutput:
+        batch_size = context.size(0)
+
+        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
+            context=context, mask=mask
+        )
         sequence_output = self.decode(input_embeds, attention_mask, hidden_states)

         quantile_preds_shape = (
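Reviewer note: forward is now split into encode plus the decoding tail, so the encoder pass can be reused by embed. A minimal sketch of calling encode directly, assuming pipeline is an already-loaded ChronosBoltPipeline:

# Sketch: calling the new encode() on the inner model (assumes `pipeline`
# is a loaded ChronosBoltPipeline; shapes follow the embed() docstring below).
import torch

context = torch.randn(2, 512)  # (batch_size, context_length)
hidden_states, loc_scale, input_embeds, attention_mask = pipeline.model.encode(
    context=context.to(pipeline.model.device, dtype=torch.float32)
)
# hidden_states: last encoder hidden states, (batch, num_patches [+ REG], d_model)
# loc_scale: (loc, scale) tensors from instance normalization of the context
# input_embeds and attention_mask are reused by forward() for self.decode(...)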
@@ -426,6 +436,46 @@ def __init__(self, model: ChronosBoltModelForForecasting):
     def quantiles(self) -> List[float]:
         return self.model.config.chronos_config["quantiles"]

+    @torch.no_grad()
+    def embed(
+        self, context: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        context
+            Input series. This is either a 1D tensor, or a list
+            of 1D tensors, or a 2D tensor whose first dimension
+            is batch. In the latter case, use left-padding with
+            ``torch.nan`` to align series of different lengths.
+
+        Returns
+        -------
+        embeddings, loc_scale
+            A tuple of two items: the encoder embeddings and the loc_scale,
+            i.e., the mean and std of the original time series.
+            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
+            where num_patches is the number of patches in the time series
+            and the extra 1 is for the [REG] token (if used by the model).
+        """
+        context_tensor = self._prepare_and_validate_context(context=context)
+        model_context_length = self.model.config.chronos_config["context_length"]
+
+        if context_tensor.shape[-1] > model_context_length:
+            context_tensor = context_tensor[..., -model_context_length:]
+
+        context_tensor = context_tensor.to(
+            device=self.model.device,
+            dtype=torch.float32,
+        )
+        embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
+        return embeddings.cpu(), (
+            loc_scale[0].squeeze(-1).cpu(),
+            loc_scale[1].squeeze(-1).cpu(),
+        )
+
     def predict(  # type: ignore[override]
         self,
         context: Union[torch.Tensor, List[torch.Tensor]],
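Taken together, the new pipeline-level API can be exercised as in the sketch below. The checkpoint name amazon/chronos-bolt-small is illustrative, and importing ChronosBoltPipeline from the package root is an assumption; any Chronos-Bolt checkpoint should work.

# Sketch: using the new pipeline.embed API added by this PR.
# The checkpoint name is illustrative; any Chronos-Bolt checkpoint should work.
import torch
from chronos import ChronosBoltPipeline

pipeline = ChronosBoltPipeline.from_pretrained(
    "amazon/chronos-bolt-small", device_map="cpu", torch_dtype=torch.float32
)
series = torch.randn(8, 128)  # 8 series of 128 observations each
embeddings, (loc, scale) = pipeline.embed(series)
# embeddings: (8, num_patches (+1 for the [REG] token), d_model), returned on CPU
# loc, scale: per-series mean and std used for instance normalization, shape (8,)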
44 changes: 44 additions & 0 deletions test/test_chronos_bolt.py
@@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles(
     validate_tensor(mean, (1, prediction_length), dtype=torch.float32)


+@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
+def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
+    pipeline = ChronosBoltPipeline.from_pretrained(
+        Path(__file__).parent / "dummy-chronos-bolt-model",
+        device_map="cpu",
+        torch_dtype=model_dtype,
+    )
+    d_model = pipeline.model.config.d_model
+    context = 10 * torch.rand(size=(4, 16)) + 10
+    context = context.to(dtype=input_dtype)
+
+    # the patch size of the dummy model is 16, so only 1 patch is created
+    expected_embed_length = 1 + (
+        1 if pipeline.model.config.chronos_config["use_reg_token"] else 0
+    )
+
+    # input: tensor of shape (batch_size, context_length)
+    embedding, loc_scale = pipeline.embed(context)
+    validate_tensor(
+        embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
+    )
+    validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
+    validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
+
+    # input: batch_size-long list of tensors of shape (context_length,)
+    embedding, loc_scale = pipeline.embed(list(context))
+    validate_tensor(
+        embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
+    )
+    validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
+    validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
+
+    # input: tensor of shape (context_length,)
+    embedding, loc_scale = pipeline.embed(context[0, ...])
+    validate_tensor(
+        embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
+    )
+    validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32)
+    validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32)
+
+
 # The following tests have been taken from
 # https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
 # Author: Caner Turkmen <atturkm@amazon.com>
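The new test can be run in isolation; a typical invocation (assuming pytest is installed and the bundled dummy checkpoint is present) is sketched below.

# Sketch: run only the new embed test via pytest's Python entry point.
import pytest

pytest.main(["test/test_chronos_bolt.py", "-k", "test_pipeline_embed", "-v"])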