
Add total training steps as a property to trainer #10760

Closed
rohitgr7 opened this issue Nov 25, 2021 · 21 comments · Fixed by #11599
Labels: feature (Is an improvement or enhancement), lr scheduler, optimization
Milestone: 1.6

Comments

@rohitgr7
Contributor

rohitgr7 commented Nov 25, 2021

🚀 Feature

See title.
Similar issues: #5449, #10430... I'll keep linking more.

Motivation

This is a highly requested feature from the Lightning community. The total number of training steps is used by some lr schedulers, especially with transformer models, but since many arguments/flags are involved in computing it, it's not easy for a user to write a helper that covers all the possible edge cases without code changes. Also, we make updates to some of these flags and to the accessibility of some internal components (e.g. train_dataloaders in v1.5), so a custom helper written by a user might get outdated soon, and writing a new one is only possible if they are well aware of the codebase internals. But we, as core contributors, can maintain it, with some tests of course.
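
For example, the linear-warmup schedule in HuggingFace Transformers takes the total step count as an argument, which users currently have to compute by hand (a minimal sketch, assuming the transformers package is installed; the model, optimizer, and step counts are placeholders):

import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

total_training_steps = 10_000  # the value this issue asks the Trainer to expose
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=total_training_steps,
)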

Pitch

Would like to credit @SeanNaren for helping out :)

import math

from pytorch_lightning.utilities import rank_zero_warn


@property
def num_training_steps(self) -> int:
    """Total training steps inferred from the datamodule and devices."""
    # Prefer the already-computed number of batches; otherwise request the dataloader.
    if self.trainer.num_training_batches != float('inf'):
        dataset_size = self.trainer.num_training_batches
    else:
        rank_zero_warn('Requesting dataloader...')
        dataset_size = len(self.trainer._data_connector._train_dataloader_source.dataloader())

    # Respect limit_train_batches, given either as an int (batch count) or a float (fraction).
    if isinstance(self.trainer.limit_train_batches, int):
        dataset_size = min(dataset_size, self.trainer.limit_train_batches)
    else:
        dataset_size = int(dataset_size * self.trainer.limit_train_batches)

    # With DP/DDP2 a single process drives all the devices on a node, so don't multiply by the device count.
    accelerator_connector = self.trainer._accelerator_connector
    if accelerator_connector.use_ddp2 or accelerator_connector.use_dp:
        effective_devices = 1
    else:
        effective_devices = self.trainer.devices

    effective_devices = effective_devices * self.trainer.num_nodes
    effective_batch_size = self.trainer.accumulate_grad_batches * effective_devices
    max_estimated_steps = math.ceil(dataset_size / effective_batch_size) * self.trainer.max_epochs

    # Cap by max_steps if it was set.
    if self.trainer.max_steps != -1:
        max_estimated_steps = min(max_estimated_steps, self.trainer.max_steps)
    return max_estimated_steps
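
For illustration, the property could then feed a scheduler in configure_optimizers (a rough sketch; OneCycleLR is just one example of a scheduler that needs the total step count up front, and max_lr is a placeholder):

import torch


def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,  # placeholder
        total_steps=self.num_training_steps,  # the property proposed above
    )
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}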

Alternatives

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.

cc @Borda
cc @PyTorchLightning/core-contributors

@rohitgr7 rohitgr7 added feature Is an improvement or enhancement optimization lr scheduler labels Nov 25, 2021
@rohitgr7 rohitgr7 added this to the 1.6 milestone Nov 25, 2021
@tchaton
Contributor

tchaton commented Nov 25, 2021

I have definitely seen this question being asked over and over.

So I believe we should do it! @awaelchli @carmocca @ananthsub?

@carmocca
Contributor

The last time this was discussed, I remember somebody mentioning that the problem with this is that it will not be correct in all circumstances and there's no way for us to know it. A problem of silent correctness.

Also, this property cannot be called from just anywhere; the attributes used inside need to be computed first.

Just pointing out possible problems, not saying we shouldn't add it.

@awaelchli
Contributor

awaelchli commented Nov 25, 2021

I raised the concern for correctness earlier. However, I do believe it is becoming more important to provide such a utility because Lightning is getting more complex and it is harder for the regular user to understand what's happening under the hood. And if users ask for it, we should provide the most accurate estimation that is possible, and at the same time document what the unknowns are (e.g., accumulation scheduling via callback).

Also, which implementation are the users asking for? The number of training steps, i.e., the number of times the training_step will be called / the size of the "dataloader", OR the number of optimization steps?
Both may be used for learning rate scheduling (?), and if you are converting from PyTorch to Lightning you may have to choose one or the other (?).
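
To make the distinction concrete (illustrative numbers only, not from any particular run):

# With gradient accumulation, the optimizer steps less often than training_step runs.
batches_per_epoch = 100
accumulate_grad_batches = 4
max_epochs = 3

training_step_calls = batches_per_epoch * max_epochs                              # 300
optimization_steps = (batches_per_epoch // accumulate_grad_batches) * max_epochs  # 75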

@SeanNaren
Contributor

Thanks for picking this up again @rohitgr7, I'm very much for this :) I've added this functionality manually into a class for both Flash and Transformers, and it would be nice to have a single function available through the trainer.

Regarding correctness, I detailed above the cases where this function breaks down. However, in my opinion the gains outweigh the correctness issue. It may be worth putting a disclaimer in the docstring that in certain cases we cannot estimate correctly!

@rohitgr7
Contributor Author

Also, this property cannot be called from anywhere, the attributes used inside need to be computed first.

I think except for the dataloader, every other argument will already be there when we initialize the trainer. But yeah, the dataloader might need a few more things that might not be there yet, e.g. dataset, batch_size... so it won't be accessible inside .setup (which I guess is okay because we don't expect users to do that inside setup).

(e.g., accumulation scheduling via callback).

Accumulation via callback will still be available, right? Since we will have it during Trainer init?

Also, which implementation are the users asking for? The number of training steps, i.e., the number of times the training_step will be called / the size of the "dataloader", OR the number of optimization steps?
Both may be used for learning rate scheduling (?), and if you are converting from PyTorch to Lightning you may have to choose one or the other (?).

It's the number of optimization steps. We can rename the method to num_optimization_steps if required.

@justusschock
Member

@awaelchli

Both may be used for learning rate scheduling (?), and if you are converting from PyTorch to Lightning you may have to choose one or the other (?).

AFAIK for lr scheduling you only use the number of optimization steps, since this is the only thing that may influence the model and the optimizer :) So I think it's fine to go with that and not provide the size of the "dataloader"

@carmocca
Contributor

In that case, after one epoch, trainer.global_step == trainer.num_optimization_steps?

@rohitgr7
Contributor Author

I think yes, it should be. Any edge case we missed?

@mariomeissner

I got TypeError: '_DataLoaderSource' object is not callable when using the function provided here. Running lightning 1.5.3.

@rohitgr7
Contributor Author

@mariomeissner in which hook/method did you call this function?

@mariomeissner

@rohitgr7 I call it inside configure_optimizers to set up the scheduler!

@awaelchli
Contributor

You probably need to change this line
dataset_size = len(self.trainer._data_connector._train_dataloader_source())
to

dataset_size = len(self.trainer._data_connector._train_dataloader_source.dataloader())

(the error TypeError: '_DataLoaderSource' object is not callable pretty much gives it away)

@lukasschmit

One thing I noticed with the provided solution: if passing gpus=-1 to the trainer, the total number of training steps will end up being negative, since effective_devices = self.trainer.devices will be -1 as well.

@rohitgr7
Contributor Author

One thing I noticed with the provided solution: if passing gpus=-1 to the trainer, the total number of training steps will end up being negative, since effective_devices = self.trainer.devices will be -1 as well.

Good catch! I think we should use num_processes instead. Need to check whether that holds for TPU and DDP2 as well.
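
Something along these lines could guard against it (a rough sketch; whether num_processes is the right attribute for TPU/DDP2 still needs checking):

# Rough sketch: avoid a negative device count when gpus=-1 leaves devices at -1.
devices = self.trainer.devices
if not isinstance(devices, int) or devices < 1:
    devices = self.trainer.num_processes  # resolved per-node process count
effective_devices = devices * self.trainer.num_nodes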

@SeanNaren
Contributor

@rohitgr7 can we add this as a property now? Definitely would be a great feature I'm interested in :)

@rohitgr7
Contributor Author

@rohitgr7 can we add this as a property now? Definitely would be a great feature I'm interested in :)

yes! just waiting for final approvals.
cc @tchaton @awaelchli @carmocca

@carmocca
Contributor

Sounds good to me.

@ananthsub
Contributor

  1. From naming alone, num_training_steps is ambiguous. For new users, it's unclear whether this means the total per process, the total per process per epoch, the total globally, or the progress counter for the number of steps completed. It's especially confusing since there's an existing num_training_batches property which has completely different semantics. How do you plan to make this easily understandable?
  2. From the code snippet, this is computed globally across all processes. How do you plan to handle potential uneven end-of-data cases (in case the dataloader returns different lengths across ranks)?
  3. Gradient accumulation scheduling can vary per epoch, so the code snippet only works if the accumulation steps are constant across all epochs. You'd need additional handling to get the accumulation steps per epoch.
  4. What should this return for infinite training or time-based training, when the step count isn't pre-specified?
  5. Do you anticipate doing this for other running stages too, like validation/testing/prediction?

@rohitgr7
Contributor Author

  1. We can think of a better name, maybe total_training_steps. Shouldn't be a blocker I guess.
  2. AFAIK, uneven inputs are supported during prediction only? For training/evaluation we still use DistributedSampler.
  3. self.trainer.accumulate_grad_batches updates at every epoch, so the value will depend on when users call it (a rough sketch of accounting for a per-epoch schedule follows below).
  4. We can return something, maybe inf, or raise a warning/error. But would a user need it in that case? Can they predict it on their own in such a case?
  5. Yes, but I don't think we have any use-case yet. If there's any, then we will.
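
For point 3, a rough sketch of what accounting for a per-epoch accumulation schedule could look like (assuming a dict schedule like the one GradientAccumulationScheduler accepts, e.g. {0: 1, 4: 4}; illustrative only, not the proposed implementation):

import math


def estimated_optimizer_steps(batches_per_epoch: int, max_epochs: int, scheduling: dict) -> int:
    """Sum optimizer steps per epoch; `scheduling` maps an epoch to the factor used from that epoch on."""
    total, factor = 0, 1
    for epoch in range(max_epochs):
        factor = scheduling.get(epoch, factor)  # keep the previous factor until a new epoch key overrides it
        total += math.ceil(batches_per_epoch / factor)
    return total


# e.g. 100 batches/epoch for 6 epochs, accumulating 1 batch up to epoch 3 and 4 batches from epoch 4:
# estimated_optimizer_steps(100, 6, {0: 1, 4: 4}) == 4 * 100 + 2 * 25 == 450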

@carmocca
Contributor

For the name, I would suggest something like expected_training_steps to differentiate it from count variables that get updated as training progresses

@tchaton
Contributor

tchaton commented Jan 21, 2022

I believe estimated_num_steps and estimated_training_num_steps might be less confusing for the users.

@carmocca carmocca moved this to In Progress in Frameworks Planning Feb 16, 2022
@carmocca carmocca moved this from In Progress to In Review in Frameworks Planning Feb 23, 2022
Repository owner moved this from In Review to Done in Frameworks Planning Feb 28, 2022