Roll-forward with fixes: Fix interaction between scheduler.step() and gradient accumulation steps, refactor schedulers to use LambdaLR, and add cosine annealing LR scheduler as a decay method. #3555

Merged
merged 4 commits into from
Aug 29, 2023


justinxzhao
Contributor

@justinxzhao justinxzhao commented Aug 29, 2023

Summarizing findings with the LR scheduler, with proposed fixes:

1. Reliance on default argument values with a broken resetting mechanism.

The main issue with the cosine decay LR PR is the incorporation of default values in the construction of the scheduler.

  • step_info.num_warmup_steps is set incorrectly due to arbitrary default argument values (steps_per_checkpoint=1000, total_steps=10000) used in the LRScheduler’s constructor.
  • step_info.num_warmup_steps is set to the correct value only on the first call to scheduler.reset(), which happens at the start of the primary train loop.
  • The original PR included a refactoring to move from an in-line lambda fn (a), to a formally constructed LambdaLR object (b).
    • (a) uses step_info.num_warmup_steps by reference, so the subsequent call to scheduler.reset() (which updates step_info.num_warmup_steps) takes effect and the lambda sees the correct number of warmup steps.
    • (b) uses step_info.num_warmup_steps by value. The subsequent call to scheduler.reset() (which updates step_info.num_warmup_steps) doesn’t update/reconstruct LambdaLR.
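
The difference between (a) and (b) can be reproduced without torch. This is a minimal sketch, not Ludwig's actual code: StepInfo, make_lambda_by_reference, and make_lambda_by_value are hypothetical names, and the linear warmup formula is an assumption chosen only for illustration.

```python
from dataclasses import dataclass


@dataclass
class StepInfo:
    """Hypothetical stand-in for the trainer's step bookkeeping."""
    num_warmup_steps: int


def make_lambda_by_reference(step_info):
    # (a) The closure holds the StepInfo object, so it reads the
    # current num_warmup_steps every time it is called.
    return lambda step: min(1.0, (step + 1) / max(1, step_info.num_warmup_steps))


def make_lambda_by_value(step_info):
    # (b) The value is copied at construction time; later updates to
    # step_info are not seen unless the lambda is rebuilt.
    num_warmup_steps = step_info.num_warmup_steps
    return lambda step: min(1.0, (step + 1) / max(1, num_warmup_steps))


step_info = StepInfo(num_warmup_steps=1000)  # derived from stale defaults
by_reference = make_lambda_by_reference(step_info)
by_value = make_lambda_by_value(step_info)

step_info.num_warmup_steps = 10  # what a reset() would do at train start

print(by_reference(4))  # computed with 10 warmup steps
print(by_value(4))      # still computed with 1000 warmup steps
```

In this sketch the by-value lambda keeps warming up over 1000 steps after the update, which mirrors the behavior described for (b).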

This is why we haven’t seen any issues with regular training even though the default arguments have been there for at least 8 months.

Proposal:

  • Remove default argument values for LRScheduler.
  • Explicitly construct a dummy scheduler (with steps_per_checkpoint=0, total_steps=0) at Trainer init time.
    • Having some scheduler object in the state dict seems necessary for checkpoint loading code paths.
    • It becomes totally obvious if the dummy scheduler is accidentally being used during training.
  • Once we have the correct values of steps_per_checkpoint and total_steps at the start of the train loop, instead of calling scheduler.reset(), construct a new LRScheduler object with these values.
    • LambdaLR will be initialized with the correct step_info.num_warmup_steps.
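
The proposal above can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the merged implementation: the LRScheduler and Trainer classes here are hypothetical reductions, and warmup_fraction is an invented parameter used only to derive num_warmup_steps.

```python
class LRScheduler:
    """Hypothetical, simplified scheduler wrapper with no default arguments."""

    def __init__(self, warmup_fraction, steps_per_checkpoint, total_steps):
        self.steps_per_checkpoint = steps_per_checkpoint
        self.total_steps = total_steps
        # Computed once, at construction, from the values actually passed in.
        self.num_warmup_steps = round(warmup_fraction * total_steps)

    def state_dict(self):
        return {
            "total_steps": self.total_steps,
            "num_warmup_steps": self.num_warmup_steps,
        }


class Trainer:
    def __init__(self, warmup_fraction):
        self.warmup_fraction = warmup_fraction
        # Dummy scheduler so checkpoint code paths always find one.
        self.scheduler = LRScheduler(
            warmup_fraction, steps_per_checkpoint=0, total_steps=0
        )

    def train(self, steps_per_checkpoint, total_steps):
        # The real values are known only here: build a fresh scheduler
        # instead of resetting the dummy one.
        self.scheduler = LRScheduler(
            self.warmup_fraction, steps_per_checkpoint, total_steps
        )
        return self.scheduler


trainer = Trainer(warmup_fraction=0.1)
print(trainer.scheduler.state_dict())  # dummy: all zeros
scheduler = trainer.train(steps_per_checkpoint=100, total_steps=2000)
print(scheduler.state_dict())
```

Because the constructor has no defaults, any call site that forgets to pass steps_per_checkpoint or total_steps fails with a TypeError instead of silently using arbitrary values.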

2. Effect of gradient_accumulation_steps on scheduler.step()'s control flow

This has been an issue independent of the cosine decay PR.

  • This OR condition decides when we call scheduler.step(): step % gradient_accumulation_steps == 0 or step == is_checkpoint_step.
  • Each time we call step(), torch scheduler’s internal current step is incremented by 1.
  • Torch schedulers have no visibility into when we call .step() on them.
  • Folding this condition into LRScheduler requires pre-calculating when scheduler.step() is called, so that we have an accurate mapping between the scheduler’s current step and the actual number of training steps that have occurred.
    A formula for this pre-calculation is possible, but complex, as it involves consolidating both gradient_accumulation_steps and steps_per_checkpoint, and passing this map to each scheduler.
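
To make the mismatch concrete, the sketch below counts how often scheduler.step() would fire under the OR condition. It rests on an assumption not stated above: a checkpoint step is taken to be every steps_per_checkpoint-th training step. The function name count_scheduler_steps is hypothetical.

```python
def count_scheduler_steps(total_steps, gradient_accumulation_steps, steps_per_checkpoint):
    """Count how often scheduler.step() fires under the OR condition."""
    calls = 0
    for step in range(1, total_steps + 1):
        is_accumulation_boundary = step % gradient_accumulation_steps == 0
        # Assumption: a checkpoint happens every steps_per_checkpoint steps.
        is_checkpoint_step = step % steps_per_checkpoint == 0
        if is_accumulation_boundary or is_checkpoint_step:
            calls += 1
    return calls


calls = count_scheduler_steps(
    total_steps=100, gradient_accumulation_steps=4, steps_per_checkpoint=10
)
print(calls)  # far fewer than the 100 training steps
```

Under these assumptions the scheduler's internal step advances only 30 times over 100 training steps, and the spacing between calls is irregular, which is why a schedule defined in training steps drifts from what the torch scheduler actually applies.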

Proposal:

  • Synchronize scheduler.step() and training steps by simply calling scheduler.step() every training step.
  • The number of training steps and the number of times scheduler.step() has been called are then synchronized.
  • LR warmup and decay are continuous and the LR will change a little bit each training step.
    • The resulting slight difference in gradient averaging should be entirely insignificant over the full training run.
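
With one scheduler.step() per training step, the LR multiplier can be written directly as a function of the training step. The sketch below assumes linear warmup followed by cosine decay to zero; the function name and the exact formula are illustrative assumptions, not necessarily what the PR implements. A function of this shape (step in, multiplicative factor out) is what torch.optim.lr_scheduler.LambdaLR accepts as its lr_lambda argument.

```python
import math


def warmup_cosine_multiplier(step, num_warmup_steps, total_steps):
    """LR multiplier in [0, 1] as a function of the training step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 toward 1.
        return step / max(1, num_warmup_steps)
    # Cosine decay from 1 to 0 over the remaining steps.
    progress = (step - num_warmup_steps) / max(1, total_steps - num_warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


base_lr = 1e-3
for step in range(0, 101, 10):
    lr = base_lr * warmup_cosine_multiplier(step, num_warmup_steps=10, total_steps=100)
    print(step, lr)
```

Since the multiplier changes only slightly between consecutive steps, the LR seen at each optimizer update is close to what it would be under any other reasonable stepping cadence.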

@github-actions

Unit Test Results

6 files ±0, 6 suites ±0, duration 1h 49m 42s (+27m 54s)
2 822 tests (+2 788): 2 805 passed (+2 776), 12 skipped (+7), 5 failed (+5)
8 431 runs (+8 343): 8 385 passed (+8 313), 31 skipped (+15), 15 failed (+15)

For more details on these failures, see this check.

Results for commit 0141421. ± Comparison against base commit f34c272.

This pull request removes 4 and adds 2792 tests. Note that renamed tests count towards both.
tests.regression_tests.model.test_old_models ‑ test_model_loaded_from_old_config_prediction_works
tests.regression_tests.model.test_old_models ‑ test_predict_deprecated_model[respiratory]
tests.regression_tests.model.test_old_models ‑ test_predict_deprecated_model[titanic]
tests.regression_tests.model.test_old_models ‑ test_predict_deprecated_model[twitter_bots]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_image_augmentation[augmentation_pipeline_ops0]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_image_augmentation[augmentation_pipeline_ops1]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_image_augmentation[augmentation_pipeline_ops2]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_invalid_augmentation_parameters[None]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_invalid_augmentation_parameters[augmentation_pipeline_ops1]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_invalid_augmentation_parameters[augmentation_pipeline_ops2]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_invalid_augmentation_parameters[augmentation_pipeline_ops4]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_invalid_augmentation_parameters[random_horizontal_flip]
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_load_model_with_augmentation_pipeline
tests.ludwig.augmentation.test_augmentation_pipeline ‑ test_local_model_training_with_augmentation_pipeline[preprocessing0-encoder0-False]
…

@justinxzhao justinxzhao marked this pull request as ready for review August 29, 2023 22:02
@justinxzhao justinxzhao merged commit fed9b82 into master Aug 29, 2023
@justinxzhao justinxzhao deleted the ff_lr branch August 29, 2023 22:10