diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py
index 31272b5..e14f963 100644
--- a/src/chronos/chronos.py
+++ b/src/chronos/chronos.py
@@ -351,7 +351,7 @@ def _prepare_and_validate_context(
     @torch.no_grad()
     def embed(
         self, context: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Any]:
         """
         Get encoder embeddings for the given time series.
 
@@ -365,21 +365,23 @@ def embed(
 
         Returns
         -------
-        embeddings, scale
-            A tuple of two tensors: the encoder embeddings and the time series scale.
+        embeddings, decoding_context
+            A tuple of two tensors: the encoder embeddings and the decoding_context,
+            e.g., the scale of the time series in the case of mean scaling.
             The encoder embeddings are shaped (batch_size, context_length, d_model)
             or (batch_size, context_length + 1, d_model), where the extra 1 is for EOS.
             If your original time series were shorter than the model's context_length,
             please slice the returned embeddings along the time axis accordingly.
-            The scale is shaped (batch_size, ).
         """
         context = self._prepare_and_validate_context(context=context)
-        token_ids, attention_mask, scale = self.tokenizer.input_transform(context)
+        token_ids, attention_mask, decoding_context = self.tokenizer.input_transform(
+            context
+        )
         embeddings = self.model.encode(
             input_ids=token_ids.to(self.model.device),
             attention_mask=attention_mask.to(self.model.device),
         ).cpu()
-        return embeddings, scale
+        return embeddings, decoding_context
 
     def predict(
         self,
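
A minimal usage sketch of the updated `embed()` return signature, assuming a pipeline loaded via `ChronosPipeline.from_pretrained`; the checkpoint name and the random input below are illustrative placeholders, not part of this change:

```python
import torch
from chronos import ChronosPipeline

# Illustrative checkpoint; any Chronos T5 checkpoint would work the same way.
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cpu",
    torch_dtype=torch.float32,
)

# A single 1D context series (placeholder data).
context = torch.randn(64)

# embed() now returns (embeddings, decoding_context) instead of (embeddings, scale).
# With the default mean-scaling tokenizer, decoding_context carries the per-series scale.
embeddings, decoding_context = pipeline.embed(context)

# embeddings: (batch_size, context_length [+ 1 for EOS], d_model)
print(embeddings.shape)
print(decoding_context)
```

Existing callers that unpacked the second element as `scale` keep working under mean scaling; the rename signals that other tokenizers may return a different kind of decoding state.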