Add support for causal models (#113)
*Description of changes:* This PR adds support for training
causal/decoder-only models.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
abdulfatir and Abdul Fatir Ansari authored Jun 13, 2024
1 parent 79028e3 commit 2f92a12
Showing 4 changed files with 96 additions and 4 deletions.
3 changes: 3 additions & 0 deletions scripts/README.md
@@ -89,6 +89,9 @@
The output and checkpoints will be saved in `output/run-{id}/`.
> [!TIP]
> If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`.
> [!IMPORTANT]
> When pretraining causal models (such as GPT2), the training script applies [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) to missing values by default. If you pretrain causal models, please ensure that missing values are imputed in the same way before passing the context tensor to `ChronosPipeline.predict()`, otherwise the forecasts may be inaccurate.
- (Optional) Once trained, you can easily push your fine-tuned model to HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to HuggingFace🤗 Hub at `<your_hf_username>/chronos-t5-small-fine-tuned`.
```py
from chronos import ChronosPipeline
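A minimal sketch (not from this commit) of what the IMPORTANT note above asks for: imputing missing values in the inference context the same way the training script does before calling `ChronosPipeline.predict()`. The checkpoint path below is hypothetical and assumes a causal model trained with this branch.

```python
import numpy as np
import torch
from gluonts.transform.feature import LastValueImputation
from chronos import ChronosPipeline

# Hypothetical local path to a causal (GPT2-style) run; substitute your own checkpoint.
pipeline = ChronosPipeline.from_pretrained("./output/run-0/checkpoint-final")

# A context with missing observations.
context = torch.tensor([1.0, 2.0, float("nan"), float("nan"), 5.0, 6.0])

# Mirror the training-time preprocessing: forward-fill the NaNs with
# LastValueImputation before handing the context to the model.
imputed = torch.from_numpy(LastValueImputation()(context.numpy().astype(np.float32)))

forecast = pipeline.predict(imputed, prediction_length=12)
print(forecast.shape)  # (1, num_samples, 12)
```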
35 changes: 35 additions & 0 deletions scripts/training/configs/chronos-gpt2.yaml
@@ -0,0 +1,35 @@
training_data_paths:
- "/home/ubuntu/tsmixup-data.arrow"
- "/home/ubuntu/kernelsynth-data.arrow"
probability:
- 0.9
- 0.1
context_length: 512
prediction_length: 64
min_past: 60
max_steps: 200_000
save_steps: 100_000
log_steps: 500
per_device_train_batch_size: 32
learning_rate: 0.001
optim: adamw_torch_fused
num_samples: 20
shuffle_buffer_length: 100_000
gradient_accumulation_steps: 1
model_id: openai-community/gpt2
model_type: causal
random_init: false
tie_embeddings: false
output_dir: ./output/
tf32: true
torch_compile: true
tokenizer_class: "MeanScaleUniformBins"
tokenizer_kwargs:
  low_limit: -15.0
  high_limit: 15.0
n_tokens: 4096
lr_scheduler_type: linear
warmup_ratio: 0.0
dataloader_num_workers: 1
max_missing_prop: 0.1
use_eos_token: true
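A quick way to sanity-check the new config (not from this commit); it assumes PyYAML is installed and that the file sits at the path added above. The printed keys are the ones that put a run into causal/decoder-only mode.

```python
import yaml

# Load the GPT2 config added in this PR and inspect the causal-specific fields.
with open("scripts/training/configs/chronos-gpt2.yaml") as f:
    config = yaml.safe_load(f)

print(config["model_id"])        # openai-community/gpt2
print(config["model_type"])      # causal
print(config["tie_embeddings"])  # False
```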
60 changes: 57 additions & 3 deletions scripts/training/train.py
@@ -39,6 +39,9 @@
    ValidationSplitSampler,
    InstanceSplitter,
    ExpectedNumInstanceSampler,
    MissingValueImputation,
    LeavesMissingValues,
    LastValueImputation,
)

from chronos import ChronosConfig, ChronosTokenizer
@@ -301,13 +304,16 @@ def __init__(
        prediction_length: int = 64,
        drop_prob: float = 0.2,
        min_past: Optional[int] = None,
        model_type: str = "seq2seq",
        imputation_method: Optional[MissingValueImputation] = None,
        mode: str = "training",
        np_dtype=np.float32,
    ) -> None:
        super().__init__()

        assert len(probabilities) == len(datasets)
        assert mode in ("training", "validation", "test")
        assert model_type in ("seq2seq", "causal")

        self.datasets = datasets
        self.probabilities = probabilities
@@ -316,6 +322,8 @@ def __init__(
        self.prediction_length = prediction_length
        self.drop_prob = drop_prob
        self.min_past = min_past or prediction_length
        self.model_type = model_type
        self.imputation_method = imputation_method or LeavesMissingValues()
        self.mode = mode
        self.np_dtype = np_dtype

@@ -324,6 +332,11 @@ def preprocess_entry(self, entry: dict, mode: str) -> dict:
entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype)
assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1"

if self.model_type == "causal":
# Causal models do not play nice with missing values, so it is
# recommended to use an imputation method, e.g., LastValueImputation
entry["target"] = self.imputation_method(entry["target"])

if mode == "training" and self.drop_prob > 0:
target = entry["target"].copy()
drop_p = np.random.uniform(low=0.0, high=self.drop_prob)
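A tiny illustration (not from this commit) of the two imputation strategies imported above, applied to a toy target array; the values in the comments are the expected forward-fill results, not captured output.

```python
import numpy as np
from gluonts.transform.feature import LastValueImputation, LeavesMissingValues

target = np.array([1.0, np.nan, np.nan, 4.0, np.nan], dtype=np.float32)

# Causal path: each NaN is replaced by the most recent observed value.
print(LastValueImputation()(target))   # expected: [1. 1. 1. 4. 4.]

# Default / seq2seq path: LeavesMissingValues is a no-op, NaNs pass through.
print(LeavesMissingValues()(target))   # expected: [ 1. nan nan  4. nan]
```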
@@ -386,6 +399,48 @@ def to_hf_format(self, entry: dict) -> dict:
        future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
        labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
        labels[labels_mask == 0] = -100

        if self.model_type == "causal":
            # The InstanceSplitter pads time series on the left to be equal to the
            # context_length. However, certain models (e.g., GPT2) with absolute
            # position embeddings should not be trained with left padding.
            # The following piece of code moves padding from left to right.

            assert input_ids.shape[-1] == entry["past_is_pad"].shape[0]

            # Find the index where the left padding ends (i.e., where observed values start)
            pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1)
            padded_input_ids, obs_input_ids = torch.tensor_split(
                input_ids, [pad_start_idx], dim=-1
            )
            padded_attention_mask, obs_attention_mask = torch.tensor_split(
                attention_mask, [pad_start_idx], dim=-1
            )

            # Move padding to the right
            input_ids = torch.cat(
                [
                    obs_input_ids,
                    labels,
                    padded_input_ids,
                ],
                axis=-1,
            )
            attention_mask = torch.cat(
                [
                    obs_attention_mask,
                    labels_mask,
                    padded_attention_mask,
                ],
                axis=-1,
            )

            # Labels for causal models are the same as the input_ids;
            # transformers internally shifts the labels by one during training.
            labels = input_ids.clone()
            input_ids[~attention_mask] = self.tokenizer.config.pad_token_id
            labels[~attention_mask] = -100

        return {
            "input_ids": input_ids.squeeze(0),
            "attention_mask": attention_mask.squeeze(0),
@@ -520,9 +575,6 @@ def main(

    assert model_type in ["seq2seq", "causal"]

    if not model_type == "seq2seq":
        raise NotImplementedError("Only seq2seq models are currently supported")

    output_dir = get_next_path("run", base_dir=output_dir, file_type="")

    log_on_main(f"Logging dir: {output_dir}", logger)
@@ -588,6 +640,8 @@ def main(
        context_length=context_length,
        prediction_length=prediction_length,
        min_past=min_past,
        model_type=model_type,
        imputation_method=LastValueImputation() if model_type == "causal" else None,
        mode="training",
    ).shuffle(shuffle_buffer_length=shuffle_buffer_length)

2 changes: 1 addition & 1 deletion src/chronos/chronos.py
@@ -551,7 +551,7 @@ def from_pretrained(cls, *args, **kwargs):
        if chronos_config.model_type == "seq2seq":
            inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
        else:
            assert config.model_type == "causal"
            assert chronos_config.model_type == "causal"
            inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)

        return cls(
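With the corrected assertion above, loading a checkpoint whose Chronos config has `model_type: causal` should route through `AutoModelForCausalLM`. A minimal sketch (not from this commit; the local path is hypothetical):

```python
from chronos import ChronosPipeline

# Hypothetical path to a causal run produced by the training script above.
pipeline = ChronosPipeline.from_pretrained("./output/run-0/checkpoint-final")

# The inner HuggingFace model should now be a causal LM (e.g., GPT2LMHeadModel)
# rather than a seq2seq model such as T5ForConditionalGeneration.
print(type(pipeline.model.model).__name__)
```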
