
Accessing performance of model in progress bar #4326

Closed
swethmandava opened this issue Oct 23, 2020 · 8 comments · Fixed by #4369
Labels: bug (Something isn't working), help wanted (Open to be worked on)

Comments

@swethmandava
Contributor

When I run training, I see the progress bar indicating iterations/sec. How can I access it? I wrote a simple hook:

import time

from pytorch_lightning.callbacks import ProgressBar


class PerfCallback(ProgressBar):

    def __init__(self):
        super().__init__()  # don't forget this :)

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        # accumulators for throughput stats
        self.total_runtime = 0
        self.unpadded_tokens = 0
        self.all_tokens = 0
        self.total_steps = 0

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
        self.t0 = time.time()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # keep the base ProgressBar behaviour, then record the batch time
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        runtime = time.time() - self.t0
        self.total_runtime += runtime
        self.total_steps += 1
        print("on batch end runtime=%.2f, it/s = %.2f" % (runtime, 1 / runtime))

However, my prints indicate ~1.7 it/s but the progress bar shows 6.07 s/it.

In my training_step, I also return some stats (batch size, number of tokens) in logs by returning the following:

{"loss": loss_tensors[0], "log":logs, "progress_bar":{"global_step":self.global_step}}

for more relevant perf metrics. However, in my callback, print(outputs) shows an empty list. Is there anything I'm missing?
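
(For context, a hedged sketch of what such a training_step might look like; loss_tensors, logs, and global_step mirror the snippet above, while the batch keys and the _compute_losses helper are purely illustrative.)

def training_step(self, batch, batch_idx):
    # hypothetical helper returning a tuple of loss tensors
    loss_tensors = self._compute_losses(batch)
    logs = {
        "batch_size": float(batch["input_ids"].size(0)),     # assumes a dict-style batch
        "num_tokens": batch["attention_mask"].sum().item(),  # assumes a padding mask
    }
    return {"loss": loss_tensors[0], "log": logs, "progress_bar": {"global_step": self.global_step}}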

@swethmandava added the question (Further information is requested) label on Oct 23, 2020
@rohitgr7
Contributor

I believe these outputs are tracked only if you implement training_epoch_end. Either way, between training_step and on_train_batch_end only the optimization step occurs, so I don't think it affects the outputs in any way. You can do whatever you want with the outputs in training_step itself.
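
(As a hedged sketch of that suggestion, assuming the same dict-style batch as above and a hypothetical _compute_losses helper, the stats could be surfaced directly from training_step via self.log:)

def training_step(self, batch, batch_idx):
    # hedged sketch: log the stats here instead of relying on `outputs` in a callback
    loss_tensors = self._compute_losses(batch)  # hypothetical helper
    self.log("num_tokens", batch["attention_mask"].sum(), prog_bar=True)
    self.log("batch_size", float(batch["input_ids"].size(0)), prog_bar=True)
    return loss_tensors[0]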

my prints indicate ~1.7 it/s but the progress bar shows 6.07 s/it.

Logging is triggered after on_train_batch_end, so that might be the reason for this difference.

@awaelchli
Contributor

@rohitgr7 Looking at the code you linked, I saw that
self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
is sending the epoch_end outputs to the callback hook, but shouldn't that just be the output of training_step at each step?
This looks wrong to me. Is it a bug?

@rohitgr7
Contributor

@awaelchli epoch_end_outputs is an empty list unless you override the training_epoch_end.
https://github.com/PyTorchLightning/pytorch-lightning/blob/66704e685d4975ed8d0a680ba7090bf5acd14d73/pytorch_lightning/trainer/training_loop.py#L923-L924

Also, epoch_end_outputs is just the output of training_step; it's just a misleading variable name, maybe.

@awaelchli
Contributor

Yes, it's a misleading name, but I also think the behaviour should be different.
The callback method should receive the output of training_step directly, I believe.
There is a comment in the code about memory build-up, but I don't see the issue if we just pass the output to the callback hook; we only keep this one reference alive until the callback hook is finished.

@rohitgr7
Contributor

Agreed 👍, it should be allowed irrespective of whether training_epoch_end is implemented or not.

@awaelchli
Contributor

Agreed 👍, it should be allowed irrespective of whether training_epoch_end is implemented or not.

For on_train_batch_end, yes certainly.
I think this simply got confused with the callback hook named on_train_epoch_end, which certainly should get all outputs in a list.

@swethmandava
Contributor Author

Thanks @rohitgr7 @awaelchli, that explains the empty outputs. Should I send a PR?

Logging is triggered after on_train_batch_end, so that might be the reason for this difference.

Where is the progress bar's iterations/sec coming from, and when is it triggered? If I want to measure training performance, is measuring the time between on_train_batch_start and on_train_batch_end accurate?

@rohitgr7
Contributor

Should I send a PR?

Yes, please go ahead.

Where is progress bar iterations/sec coming from and when is it triggered?

It uses tqdm.
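
(A side note on tqdm itself, independent of Lightning internals: the displayed rate is an exponentially smoothed average over recent iterations and flips to s/it once a step takes longer than a second, so it will not match a single-step print exactly. A minimal standalone sketch:)

import time
from tqdm import tqdm

# Standalone tqdm example, unrelated to Lightning: the bar reports a smoothed
# rate and shows "s/it" when iterations are slower than one per second.
for _ in tqdm(range(10), smoothing=0.3):  # smoothing=0.3 is tqdm's default
    time.sleep(1.5)  # slower than 1 it/s, so the bar settles around ~1.5 s/it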

is measuring the time between on_train_batch_start and on_train_batch_end accurate?

Yes, almost, but if you want to measure a bit more precisely, then I suggest disabling logging.
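
(Along those lines, a hedged sketch of a standalone timing callback, assuming the hook signatures shown earlier in this issue; time.perf_counter, the running average, and the logger=False usage note are my additions, not code from this thread.)

import time

from pytorch_lightning.callbacks import Callback

# Hedged sketch: time each batch between on_train_batch_start and
# on_train_batch_end and report a running-average throughput.
class BatchTimer(Callback):
    def on_train_start(self, trainer, pl_module):
        self.total_runtime = 0.0
        self.total_steps = 0

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self.t0 = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.total_runtime += time.perf_counter() - self.t0
        self.total_steps += 1
        print("running average: %.2f it/s" % (self.total_steps / self.total_runtime))

# Illustrative usage, with logging disabled for a cleaner measurement:
# trainer = Trainer(callbacks=[BatchTimer()], logger=False)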

@edenlightning added the bug (Something isn't working) and help wanted (Open to be worked on) labels and removed the question (Further information is requested) label on Oct 27, 2020
@edenlightning added this to the 1.0.x milestone on Oct 27, 2020
@edenlightning modified the milestones: 1.0.x, 1.0.7 on Nov 10, 2020
@Borda modified the milestones: 1.0.7, 1.0.x on Nov 11, 2020
@edenlightning modified the milestones: 1.0.x, 1.0.7 on Nov 13, 2020
@Borda modified the milestones: 1.0.7, 1.0.x, 1.1 on Nov 13, 2020
@Borda modified the milestones: 1.1, 1.1.x on Nov 30, 2020