
resume_from_checkpoint function fails because "There seems to be not a single sample in your epoch_iterator" #26413

Closed
2 of 4 tasks
omermazig opened this issue Sep 26, 2023 · 23 comments · May be fixed by #33544

@omermazig

System Info

transformers version - 4.33.2

I'm using the Trainer API as follows, so that it pushes the latest checkpoint to the Hugging Face Hub each epoch:

from transformers import TrainingArguments, Trainer

new_model_name = "videomae-finetuned"
num_epochs = 50
batch_size = 8
steps_per_epoch = train_dataset.num_videos // batch_size

args = TrainingArguments(
    output_dir=new_model_name,
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,  # Only the last 2 checkpoints are kept; older ones are deleted.
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    logging_steps=10,
    max_steps=steps_per_epoch * num_epochs,  # Duplicates `num_train_epochs`, because the Trainer throws otherwise.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    hub_strategy="checkpoint",
    push_to_hub=True,
    num_train_epochs=num_epochs,
)
from transformers import EarlyStoppingCallback

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10, early_stopping_threshold=0.01)]
)
import traceback

try:
    results = trainer.train()
except RuntimeError as e:
    print(traceback.format_exc())

After about 25 epochs there's some exception (never mind what). So I take the last checkpoint that was saved to the Hugging Face Hub (from here, if it matters), put it on my Drive, and change the training code to this:

import pathlib
import traceback

try:
    results = trainer.train(resume_from_checkpoint=pathlib.Path(f"./drive/MyDrive/").joinpath("last-checkpoint"))
except RuntimeError as e:
    print(traceback.format_exc())

Then I rerun the whole notebook. After some time (not immediately), it prints:

There seems to be not a single sample in your epoch_iterator, stopping training at step 5500! This is expected if you're using an IterableDataset and set num_steps (12500) higher than the number of available samples.

And then it fails.

I do have an IterableDataset with 2000 training videos, I'm using batch size 8, and I want to run for 50 epochs, so I'm pretty sure 12500 is (2000/8)*50. But I still don't understand the message: why is it a problem that num_steps (12500) > number of samples (2000)?

Thank you!

Who can help?

@muellerzr
@pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I can't really provide one for my code, but it is based on your guide, and I believe the issue will reproduce there as well.

Expected behavior

Training continues from the same state at which it previously stopped.

@omermazig
Author

omermazig commented Sep 26, 2023

Update: I added ignore_data_skip=True to TrainingArguments, and it successfully ran a single epoch, then failed with:

ValueError: 'videomae-finetuned/checkpoint-3000' is not in list

Checkpoint 3000 is my best checkpoint (according to my metric_for_best_model), so I'm assuming that I have to have both the last checkpoint AND the best checkpoint available in the output dir, for this to work? If so, the documentation for hub_strategy is mistaken, because it stated:

"checkpoint": like "every_save" but the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to resume training easily with trainer.train(resume_from_checkpoint="last-checkpoint").

Which is wrong.

Am I missing something?

@LysandreJik
Member

cc @pacman100 @muellerzr

@K-Niu

K-Niu commented Oct 20, 2023

Similar question: for the sake of reproducibility, I would like to be able to resume training from the same batch where I left off in my IterableDataset (so I don't want to set ignore_data_skip=True). However, it appears that the training loop relies on the train_dataloader length to compute information needed in the resumption logic.

Is there any way to achieve this behavior? Thanks!
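(As a rough, hedged illustration of the kind of batch skipping such resumption needs, the sketch below uses accelerate's skip_first_batches outside of the Trainer; the toy dataset and the number of already-consumed batches are assumptions for the example, and this does not restore any shuffling or RNG state.)

from accelerate import skip_first_batches
from torch.utils.data import DataLoader, IterableDataset


class ToyIterableDataset(IterableDataset):
    """A tiny stand-in for a real streaming dataset."""

    def __iter__(self):
        for i in range(10):
            yield {"input_ids": [i]}


dataloader = DataLoader(ToyIterableDataset(), batch_size=2)
# Suppose the checkpoint's trainer_state.json indicates that 3 batches of the
# current epoch had already been consumed before the interruption.
resumed_dataloader = skip_first_batches(dataloader, num_batches=3)

for batch in resumed_dataloader:
    print(batch)  # iteration resumes from the 4th batch of the epoch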

@Ubadub
Contributor

Ubadub commented Nov 24, 2023

Is there any update from the team on the issues raised above? These issues make it prohibitively expensive or practically impossible to make use of an IterableDataset in certain contexts (e.g. preemptible runs).

Alternatively, any advice on working with large datasets without using an IterableDataset? Due to issue #8818, which was mistakenly closed as stale without actually being resolved, you are essentially forced to use an IterableDataset rather than a regular dataset when working with large datasets. Perhaps there is a workaround I am not aware of.

@Ubadub
Contributor

Ubadub commented Nov 24, 2023

CCing again: @muellerzr @pacman100

@pacman100
Contributor

Hello, your iterable dataset should re-iterate when reaching the end if the number of steps > the number of samples in the iterable dataset. The best example of this is the ConstantLengthDataset from the trl library. The main code snippet is given below; with infinite=True, the number of steps can be greater than the number of samples in the iterable dataset.

try:
    buffer.append(self.formatting_func(next(iterator)))
    buffer_len += len(buffer[-1])
except StopIteration:
    if self.infinite:
        iterator = iter(self.dataset)
        warnings.warn("The dataset reached end and the iterator is reset to the start.")
    else:
        more_examples = False
        break

Notice the logic in the exception handling that reassigns iterator = iter(self.dataset), and the corresponding warning "The dataset reached end and the iterator is reset to the start."
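Applying the same idea generically, a minimal sketch of a wrapper that restarts any finite iterable dataset when it is exhausted (a hypothetical helper for illustration, not part of transformers or trl) could look like this:

import warnings

from torch.utils.data import IterableDataset


class InfiniteIterableDataset(IterableDataset):
    """Wrap a finite iterable dataset and restart it whenever it is exhausted."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        while True:
            num_yielded = 0
            for sample in self.dataset:
                num_yielded += 1
                yield sample
            if num_yielded == 0:
                return  # avoid looping forever on an empty dataset
            warnings.warn("The dataset reached end and the iterator is reset to the start.")

Passing such a wrapper as train_dataset keeps the epoch_iterator from running dry, at the cost of losing the original epoch boundaries.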

Hope this helps.

@pacman100
Contributor

> ValueError: 'videomae-finetuned/checkpoint-3000' is not in list
>
> Checkpoint 3000 is my best checkpoint (according to my metric_for_best_model), so I'm assuming that I have to have both the last checkpoint AND the best checkpoint available in the output dir, for this to work? If so, the documentation for hub_strategy is mistaken, because it stated:
>
> "checkpoint": like "every_save" but the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to resume training easily with trainer.train(resume_from_checkpoint="last-checkpoint").
>
> Which is wrong.
>
> Am I missing something?

This seems like a separate issue. Please open another one with a minimal reproducible example. Currently, the given details aren't enough for us to reproduce this.

@omermazig
Author

> ValueError: 'videomae-finetuned/checkpoint-3000' is not in list
>
> Checkpoint 3000 is my best checkpoint (according to my metric_for_best_model), so I'm assuming that I have to have both the last checkpoint AND the best checkpoint available in the output dir, for this to work? If so, the documentation for hub_strategy is mistaken, because it stated:
> "checkpoint": like "every_save" but the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to resume training easily with trainer.train(resume_from_checkpoint="last-checkpoint").
> Which is wrong.
> Am I missing something?
>
> This seems like a separate issue. Please open another one with a minimal reproducible example. Currently, the given details aren't enough for us to reproduce this.

Done here:

#27728

I'm closing this issue because ignore_data_skip=True works for me. If someone doesn't find @pacman100's solution workable, please reopen this.

@Ubadub
Contributor

Ubadub commented Nov 27, 2023

> Hello, your iterable dataset should re-iterate when reaching the end if the number of steps > the number of samples in the iterable dataset.

I'm sorry, but this response doesn't make sense, and this issue should not have been closed so prematurely. The max number of steps passed to the trainer indicates the maximum number of steps over the entire training run. However, when resuming from checkpoint, the run stops training if the number of samples within a single epoch is less than the number of steps.

To clarify, you are technically correct that "your iterable dataset should reiterate when reaching the end." However, the Trainer and/or IterableDataset classes should handle this, as they already do when not resuming from checkpoint.

It is unclear why resuming from checkpoint causes them to fail to handle this. When not resuming from checkpoint, the training logic is as you expect: if you run out of samples in the current epoch but haven't reached max steps yet, you just start a new epoch until you do reach max steps.
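A minimal, self-contained sketch of that expected control flow (hypothetical pseudologic for illustration, not the actual Trainer implementation):

def train_until_max_steps(make_epoch_iterator, training_step, max_steps):
    """Keep starting new epochs until max_steps steps have been run."""
    global_step = 0
    while global_step < max_steps:
        for batch in make_epoch_iterator():  # start (or restart) an epoch
            training_step(batch)
            global_step += 1
            if global_step >= max_steps:
                break
    return global_step


# Toy check: a 10-sample "dataset" trained for 20 steps spans two epochs.
assert train_until_max_steps(lambda: iter(range(10)), lambda batch: None, max_steps=20) == 20

The same outer loop would be expected when resuming: skip the already-consumed batches, then keep cycling through epochs until max_steps is reached.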

@omermazig omermazig reopened this Nov 27, 2023
@Ubadub
Contributor

Ubadub commented Nov 27, 2023

> I'm closing this issue because ignore_data_skip=True works for me. If someone doesn't find @pacman100's solution workable, please reopen this.

This issue should not be closed because ignore_data_skip=True is not a real solution to this problem as it changes the logic of the training run and eliminates reproducibility.

I feel there is some fundamental miscommunication happening here, because it seems transparently obvious to me that this is not how it should work.

I have had identical runs where:

  1. one run got preempted before reaching the second epoch (with save_strategy=checkpoint), therefore resumed from the first-epoch checkpoint, and then errored
  2. another run continued past the second epoch

This is clear, incontrovertible evidence of a bug since it indicates different training logic is happening depending on whether resume_from_checkpoint is True or False.

Let me put it another way: If you agree that

number of training steps = (desired number of epochs) * (number of samples)/(batch size)

which implies

number of desired training steps =  (number of samples) * (desired number of epochs)/(batch size)

then do you agree that (number of desired training steps) > (number of samples) if and only if (desired number of epochs) > (batch size)?

That condition will hold in many cases, and yet it is precisely the condition on which the error triggers, at least according to the error message.
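Plugging in the numbers reported earlier in this thread, as a quick arithmetic check:

num_samples = 2000  # training videos
batch_size = 8
num_epochs = 50

num_steps = num_epochs * num_samples // batch_size
assert num_steps == 12500        # the value quoted in the error message
assert num_steps > num_samples   # holds exactly because num_epochs (50) > batch_size (8)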

When not resuming from checkpoint, this simple mathematical fact poses no problem. It is only when resuming from checkpoint that for some reason this inequality poses a conundrum, and that is what makes no sense.

@Ubadub
Contributor

Ubadub commented Nov 27, 2023

I am just now realizing that the example dataset @pacman100 provides as a working solution is a fully written Dataset class. However, the point of this issue is that it happens with a Dataset class provided by Hugging Face itself, namely the IterableDataset class. The expectation is that HF datasets should "just work" with the HF Trainer, especially since this incompatibility is not identified in the docs, as far as I know. Perhaps I am incorrect on the latter count.

@pacman100
Contributor

Hello @Ubadub, please provide a minimal reproducible example for this, along with the related config, the launch command, and the versions of the libraries.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jan 1, 2024
@muupan

muupan commented Jan 11, 2024

@pacman100 Hello, I am also facing the same issue as @Ubadub is reporting. Here is my code to reproduce the issue:

import os
import shutil
import transformers
import datasets


if os.path.exists("./output"):
    shutil.rmtree("./output")


def my_generator():
    for i in range(10):
        yield {"input_ids": [1000], "labels": [1000]}


# This dataset yields 10 examples only, but let's set max_steps=20.
dataset = datasets.IterableDataset.from_generator(my_generator)
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
args = transformers.TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    max_steps=20,
    save_steps=10,
    report_to="none",
)
trainer = transformers.Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
)
trainer.train()
# Trainer runs 20 steps, producing both checkpoint-10 and checkpoint-20.
assert os.path.exists("./output/checkpoint-10")
assert os.path.exists("./output/checkpoint-20")

# Now remove checkpoint-20 and resume training from checkpoint-10.
shutil.rmtree("./output/checkpoint-20")
trainer = transformers.Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
)
trainer.train(resume_from_checkpoint=True)
# This time, the trainer does nothing and checkpoint-20 is not produced.
assert os.path.exists("./output/checkpoint-10")
assert not os.path.exists("./output/checkpoint-20")

output:

{'train_runtime': 20.8257, 'train_samples_per_second': 0.96, 'train_steps_per_second': 0.96, 'train_loss': 0.0, 'epoch': 1.5}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:20<00:00,  1.04s/it]
There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
  0%|                                                                                                                                                                            | 0/20 [00:00<?, ?it/s]
There seems to be not a single sample in your epoch_iterator, stopping training at step 10! This is expected if you're using an IterableDataset and set num_steps (20) higher than the number of available samples.
{'train_runtime': 0.0044, 'train_samples_per_second': 4513.401, 'train_steps_per_second': 4513.401, 'train_loss': 0.0, 'epoch': 0.5}
  0%|

When not resuming, Trainer runs until 20 steps. When resuming from a checkpoint, it tries to run until 10 steps. This seems inconsistent.

As discussed in #26635, I think the correct behavior suggested by the current documentation of max_steps should be Trainer reiterating the dataset until 20 steps are executed even if the dataset is finite and smaller than 20.

max_steps (`int`, *optional*, defaults to -1):
    If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
    For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
    `max_steps` is reached.

I'm using Python v3.10.12, transformers==4.36.2, datasets==2.16.1, accelerate==0.26.0, torch==2.1.2.

@ArthurZucker ArthurZucker reopened this Jan 11, 2024
@Ubadub
Contributor

Ubadub commented Jan 11, 2024

@muupan Thank you for the minimal example. I had a lot on my plate and was unable to put one together myself, so I ended up scrapping the use of this functionality altogether, but that introduced its own complications, so I would really appreciate a fix for this.

@Ubadub
Contributor

Ubadub commented Feb 5, 2024

Confirming this issue should not be marked stale and still requires addressing.

@amyeroberts
Collaborator

Gentle ping @muellerzr @pacman100

@amyeroberts
Collaborator

cc @muellerzr @SunMarc

@amyeroberts
Collaborator

Gentle ping @muellerzr @SunMarc

@R2D2oid

R2D2oid commented Aug 10, 2024

Does anyone have a workaround or a solution for this yet?

@muupan

muupan commented Sep 14, 2024

My current workaround is making the dataset yield samples infinitely. In my example code, if I replace the definition of my_generator with

def my_generator():
    while True:
        for i in range(10):
            yield {"input_ids": [1000], "labels": [1000]}

then the resumed training correctly continues until step 20. However, this workaround has a drawback: one epoch now corresponds to max_steps steps, so you cannot count epochs in the original sense.
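The same cycling can also be layered on top of the original finite generator without rewriting it; here is a minimal sketch of that variant (hypothetical helper names, and it shares the same epoch-counting drawback):

import datasets


def my_generator():
    for i in range(10):
        yield {"input_ids": [1000], "labels": [1000]}


def my_cycled_generator():
    """Repeat the finite generator forever so the epoch_iterator never runs dry."""
    while True:
        yield from my_generator()


dataset = datasets.IterableDataset.from_generator(my_cycled_generator)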

@muupan

muupan commented Sep 17, 2024

I implemented a solution and opened a PR: #33544


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker ArthurZucker reopened this Oct 22, 2024
@SunMarc SunMarc reopened this Nov 5, 2024