diff --git a/.github/workflows/eval-model.yml b/.github/workflows/eval-model.yml
index 95f7e27..991d186 100644
--- a/.github/workflows/eval-model.yml
+++ b/.github/workflows/eval-model.yml
@@ -12,7 +12,7 @@ on:
       - labeled # When a label is added to the PR
 
 jobs:
-  evaluate-and-post:
+  evaluate-and-print:
     if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
     runs-on: ubuntu-latest
     env:
@@ -33,10 +33,5 @@ jobs:
       - name: Run Eval Script
         run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32
 
-      - name: Upload CSV
-        uses: actions/upload-artifact@v4
-        with:
-          name: eval-metrics
-          path: ${{ env.RESULTS_CSV }}
-          retention-days: 1
-          overwrite: true
+      - name: Print CSV
+        run: cat $RESULTS_CSV
diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py
index 4825466..ed99701 100644
--- a/src/chronos/chronos_bolt.py
+++ b/src/chronos/chronos_bolt.py
@@ -25,7 +25,6 @@
 
 from .base import BaseChronosPipeline, ForecastType
 
-
 logger = logging.getLogger(__file__)
 
 
@@ -240,13 +239,11 @@ def _init_weights(self, module):
             ):
                 module.output_layer.bias.data.zero_()
 
-    def forward(
-        self,
-        context: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        target: Optional[torch.Tensor] = None,
-        target_mask: Optional[torch.Tensor] = None,
-    ) -> ChronosBoltOutput:
+    def encode(
+        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> Tuple[
+        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
+    ]:
         mask = (
             mask.to(context.dtype)
             if mask is not None
@@ -301,8 +298,21 @@ def forward(
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
         )
-        hidden_states = encoder_outputs[0]
+        return encoder_outputs[0], loc_scale, input_embeds, attention_mask
+
+    def forward(
+        self,
+        context: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        target: Optional[torch.Tensor] = None,
+        target_mask: Optional[torch.Tensor] = None,
+    ) -> ChronosBoltOutput:
+        batch_size = context.size(0)
+
+        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
+            context=context, mask=mask
+        )
 
         sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
 
         quantile_preds_shape = (
@@ -426,6 +436,46 @@ def __init__(self, model: ChronosBoltModelForForecasting):
     def quantiles(self) -> List[float]:
         return self.model.config.chronos_config["quantiles"]
 
+    @torch.no_grad()
+    def embed(
+        self, context: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        context
+            Input series. This is either a 1D tensor, or a list
+            of 1D tensors, or a 2D tensor whose first dimension
+            is batch. In the latter case, use left-padding with
+            ``torch.nan`` to align series of different lengths.
+
+        Returns
+        -------
+        embeddings, loc_scale
+            A tuple of two items: the encoder embeddings and the loc_scale,
+            i.e., the mean and std of the original time series.
+            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
+            where num_patches is the number of patches in the time series
+            and the extra 1 is for the [REG] token (if used by the model).
+ """ + context_tensor = self._prepare_and_validate_context(context=context) + model_context_length = self.model.config.chronos_config["context_length"] + + if context_tensor.shape[-1] > model_context_length: + context_tensor = context_tensor[..., -model_context_length:] + + context_tensor = context_tensor.to( + device=self.model.device, + dtype=torch.float32, + ) + embeddings, loc_scale, *_ = self.model.encode(context=context_tensor) + return embeddings.cpu(), ( + loc_scale[0].squeeze(-1).cpu(), + loc_scale[1].squeeze(-1).cpu(), + ) + def predict( # type: ignore[override] self, context: Union[torch.Tensor, List[torch.Tensor]], diff --git a/test/test_chronos_bolt.py b/test/test_chronos_bolt.py index 4b72568..657d0df 100644 --- a/test/test_chronos_bolt.py +++ b/test/test_chronos_bolt.py @@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles( validate_tensor(mean, (1, prediction_length), dtype=torch.float32) +@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) +def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): + pipeline = ChronosBoltPipeline.from_pretrained( + Path(__file__).parent / "dummy-chronos-bolt-model", + device_map="cpu", + torch_dtype=model_dtype, + ) + d_model = pipeline.model.config.d_model + context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) + + # the patch size of dummy model is 16, so only 1 patch is created + expected_embed_length = 1 + ( + 1 if pipeline.model.config.chronos_config["use_reg_token"] else 0 + ) + + # input: tensor of shape (batch_size, context_length) + + embedding, loc_scale = pipeline.embed(context) + validate_tensor( + embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32) + validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32) + + # input: batch_size-long list of tensors of shape (context_length,) + + embedding, loc_scale = pipeline.embed(list(context)) + validate_tensor( + embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32) + validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32) + + # input: tensor of shape (context_length,) + embedding, loc_scale = pipeline.embed(context[0, ...]) + validate_tensor( + embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype + ) + validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32) + validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32) + + # The following tests have been taken from # https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py # Author: Caner Turkmen