Skip to content

Commit

Permalink
Rename scale to decoding context
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdul Fatir Ansari committed Mar 25, 2024
1 parent 4526e30 commit b1386af
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/chronos/chronos.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _prepare_and_validate_context(
@torch.no_grad()
def embed(
self, context: Union[torch.Tensor, List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Any]:
"""
Get encoder embeddings for the given time series.
Expand All @@ -365,21 +365,23 @@ def embed(
Returns
-------
embeddings, scale
A tuple of two tensors: the encoder embeddings and the time series scale.
embeddings, decoding_context
A tuple of two elements: the encoder embeddings and the decoding_context,
e.g., the scale of the time series in the case of mean scaling.
The encoder embeddings are shaped (batch_size, context_length, d_model)
or (batch_size, context_length + 1, d_model), where the extra 1 is for EOS.
If your original time series were shorter than the model's context_length,
please slice the returned embeddings along the time axis accordingly.
The decoding_context is shaped (batch_size, ) in the case of mean scaling.
"""
context = self._prepare_and_validate_context(context=context)
token_ids, attention_mask, scale = self.tokenizer.input_transform(context)
token_ids, attention_mask, decoding_context = self.tokenizer.input_transform(
context
)
embeddings = self.model.encode(
input_ids=token_ids.to(self.model.device),
attention_mask=attention_mask.to(self.model.device),
).cpu()
return embeddings, scale
return embeddings, decoding_context

def predict(
self,
Expand Down

0 comments on commit b1386af

Please sign in to comment.