
better checking of data returned from training_step #1256

Closed
jeremyjordan opened this issue Mar 27, 2020 · 12 comments
Labels: feature (Is an improvement or enhancement), good first issue (Good for newcomers), won't fix (This will not be worked on)

Comments

@jeremyjordan
Contributor

🚀 Feature

let's add more validation checks on what's returned from training_step and provide the user with useful error messages when they're not returning the right values.

Motivation

i feel like i've seen a lot of users confused about what they're supposed to return in training_step and validation_step. additionally, i don't think we document how we treat extra keys as "callback metrics" very well.

Pitch

what do you think about adding some structure and validation to Trainer's process_output method?

right now, we have expectations about a set of keys {progress_bar, log, loss, hiddens} and assume everything else is a callback metric. however, this is a silent assumption.

we could instead enforce a more rigid structure:

{
    'loss': loss,            # REQUIRED
    'log': {},               # optional dict
    'progress_bar': {},      # optional dict
    'hiddens': [h0, c0],     # optional collection of tensors
    'metrics': {},           # optional dict
}

moreover, we can leverage pydantic to do validation automatically and provide useful error messages out of the box when data validation fails.
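for illustration, a minimal sketch of what a pydantic-based schema for the structure above could look like (hypothetical: TrainingStepOutput is not an existing Lightning class, and this assumes pydantic v1-style validators):

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel, validator


class TrainingStepOutput(BaseModel):
    # hypothetical schema mirroring the proposed structure above
    loss: torch.Tensor                             # REQUIRED
    log: Dict[str, Any] = {}                       # optional dict
    progress_bar: Dict[str, Any] = {}              # optional dict
    hiddens: Optional[List[torch.Tensor]] = None   # optional collection of tensors
    metrics: Dict[str, Any] = {}                   # optional dict

    class Config:
        arbitrary_types_allowed = True  # let pydantic accept torch.Tensor fields

    @validator('loss')
    def loss_must_be_scalar(cls, v):
        # fail early with a readable message instead of a cryptic downstream error
        if v.dim() != 0:
            raise ValueError(f"'loss' must be a scalar tensor, got shape {tuple(v.shape)}")
        return v

validating the user's return value would then be a one-liner like TrainingStepOutput(**output), and a missing loss or a non-scalar tensor surfaces as a pydantic ValidationError naming the offending field.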

cc @PyTorchLightning/core-contributors

Alternatives

Do nothing, keep things as they are.

Additional context

This would be a backwards incompatible change.

@jeremyjordan jeremyjordan added the feature (Is an improvement or enhancement) and help wanted (Open to be worked on) labels Mar 27, 2020
@Borda Borda added the good first issue (Good for newcomers) label Mar 27, 2020
@rotalex

rotalex commented Mar 27, 2020

I would like to work on this.

@Borda
Member

Borda commented Mar 27, 2020

@rotalex cool, looking forward to seeing a PR from you :]

@jeremyjordan
Contributor Author

@Borda given that this proposal is backwards incompatible, i think we should get more core contributors to weigh in on the proposed design before moving forward and implementing it.

one thing that still gives me pause is the amount of overlap between log, progress_bar, and metrics. progress_bar is almost always a subset of log, and metrics (or, as they currently stand, arbitrary keys) are typically used to store temporary values to be collated and logged at the end of an epoch. i think there's room for improvement here.

@Borda
Member

Borda commented Mar 28, 2020

@jeremyjordan good point! we recently had an issue asking "why are there two dicts, one for the progress bar and one general, when they hold the same values", so some simplification or a more structured approach would be welcome...
cc: @PyTorchLightning/core-contributors ^^

@williamFalcon
Contributor

log and progress_bar were separated a while back because people wanted to log different things that they didn't want in the progress bar

@williamFalcon
Contributor

i don't really know what metrics is.

@jeremyjordan
Contributor Author

the usage of the log and progress_bar keys is clear. however, if you look at the process_output method you'll see:

# ---------------
# EXTRACT CALLBACK KEYS
# ---------------
# all keys not progress_bar or log are candidates for callbacks
callback_metrics = {}
for k, v in output.items():
    if k not in ['progress_bar', 'log', 'hiddens']:
        callback_metrics[k] = v

if train and (self.use_dp or self.use_ddp2):
    num_gpus = self.num_gpus
    callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)

for k, v in callback_metrics.items():
    if isinstance(v, torch.Tensor):
        callback_metrics[k] = v.item()

all keys not progress_bar or log are candidates for callbacks

as far as i know, this isn't documented anywhere.

if you look in the documentation, however, you will see references to keys which are not included in the set of {loss, log, progress_bar}, but the only hint about how to use them is through the examples we provide (e.g. val_loss below):

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': val_loss_mean}

now that i think about this further, a better solution might be to add more detail to the documentation about how we collect outputs from training steps and expose them at the end of an epoch, to make this clearer. furthermore, we should document that if you're returning a torch tensor, we expect it to be a scalar value.

the second question is whether or not we want to do more explicit validation of the data returned by the user. the motivation for this github issue is #1236, where we could have helped the user track down the source of an error more quickly. imagine if the error raised to the user was instead ValidationError: Cannot reduce key 'val_loss' to a scalar.
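a minimal sketch of what such an explicit check could look like inside process_output (a hypothetical helper, not existing Lightning code):

import torch

def _check_reducible(key, value):
    # hypothetical helper: raise a readable error instead of the cryptic
    # failure from .item() when a returned tensor has more than one element
    if isinstance(value, torch.Tensor) and value.dim() != 0:
        raise ValueError(
            f"Cannot reduce key '{key}' to a scalar; got a tensor of shape "
            f"{tuple(value.shape)}. Return a scalar tensor (e.g. take .mean())."
        )

# e.g. inside process_output, before converting tensors with .item():
# for k, v in callback_metrics.items():
#     _check_reducible(k, v)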

@Borda
Member

Borda commented Mar 29, 2020

I would be in favour of a more rigid structure; see also #1277

@gabisurita

gabisurita commented Apr 30, 2020

Shouldn't we favor a strongly typed return value? I've always wondered why the step return type is not a dataclass or named tuple where loss is a required field. We could keep the flexibility with some metadata dict argument.

@williamFalcon
Contributor

williamFalcon commented Apr 30, 2020

i wouldn't mind stronger typing, but i don't want to start adding APIs to remember?
although it might be simpler to remember the structured type instead of the possible keys?

@tullie @ashwinb @Darktex thoughts?

A potential way to add structure is (something like this):

def training_step(...):
    output = pl.StepResult(loss=loss, logs=logs, progress_bar=progress_bar)
    return output

I guess the only thing this helps with is that the user doesn't have to remember what the keys in the dict are?

Pro

Removes confusion with what keys do what in the return

Con

Adds an API users have to remember
(although you could argue that remembering to put "loss" in a dict is just as bad)
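a minimal sketch of what such a StepResult container could look like, assuming a plain dataclass (hypothetical; not an existing Lightning class at this point):

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch

@dataclass
class StepResult:
    # hypothetical structured return type for training_step
    loss: torch.Tensor                                              # required
    logs: Dict[str, torch.Tensor] = field(default_factory=dict)
    progress_bar: Dict[str, torch.Tensor] = field(default_factory=dict)
    hiddens: Optional[List[torch.Tensor]] = None

    def __post_init__(self):
        if not isinstance(self.loss, torch.Tensor) or self.loss.dim() != 0:
            raise ValueError("'loss' must be a scalar torch.Tensor")

the trainer could still accept a plain dict and build the StepResult from it internally, which would keep backwards compatibility.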

@tullie
Contributor

tullie commented Apr 30, 2020

My preference would be a StepResult class. I think it's the best way to document what the outputs should be. We can still support returning a dictionary and just build the StepResult class from the dict on the trainer side anyway. It'd be great if we could create a unified way for callbacks to specify which arguments are required in the StepResult too.

The progress_bar and log overlap that @jeremyjordan brought up isn't ideal. I'd love to hear how others think we should address this. The best I could come up with is specifying the desired keys for log and progress bar somewhere in the LightningModule init or as a trainer callback argument. The user would then put all result values in the step result dictionary and the specified keys would be picked out for the respective outputs (log and/or progress bar).
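a rough sketch of that routing idea, using hypothetical log_keys / progress_bar_keys attributes declared once on the module (these attribute names do not exist in Lightning; this is only an illustration):

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)
        # hypothetical attributes: declare once which result keys the trainer
        # should route to the logger and to the progress bar
        self.log_keys = ['train_loss', 'train_acc']
        self.progress_bar_keys = ['train_loss']

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layer(x.view(x.size(0), -1))
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        # everything lives in one flat dict; the trainer would pick out the
        # declared keys instead of the user building separate
        # 'log' and 'progress_bar' dicts
        return {'loss': loss, 'train_loss': loss, 'train_acc': acc}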

@stale

stale bot commented Jun 30, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix (This will not be worked on) label Jun 30, 2020
@stale stale bot closed this as completed Jul 9, 2020