Miscellaneous fixes to the x-transformers implementation #79
Conversation
Skip backward if loss is NaN. Stop training if enough batches are skipped.
The default value must be either zero or None, depending on whether accuracy is reported or not.
Parameters in the TransformerWrapper, e.g. to_logits, need their own distributed component and optimizer.
The adapter injection code was causing parameter duplication.

Another issue: to normalize or not to normalize? We compute a normalization based on either tokens or sents, but never apply it. The effect can be compensated for with the learning rate, as long as batches are of approximately the same size. Too high a learning rate leads to gradient clipping, which is extra detrimental here because each component is clipped individually. Clipping deterministically requires one of the following:
- access to the gradients of all parameters of the entire model (infeasible),
- component-local clipping (the current approach), or
- communicating a clipping factor across devices (maybe we should do this? a sketch follows below).
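If the third option were chosen, it could look roughly like the sketch below. This is only a sketch: it assumes torch.distributed is already initialized and that each parameter's gradient is counted on exactly one rank; `global_grad_clip_factor` and `components` are hypothetical names, not MAMMOTH code.

```python
import torch
import torch.distributed as dist


def global_grad_clip_factor(components, max_norm: float, device) -> float:
    """Compute one deterministic clipping factor over all components on all devices.

    Only a single scalar (the local sum of squared gradient norms) is communicated.
    `components` is a hypothetical iterable of the modules whose parameters live
    on this device; other devices hold other components.
    """
    local_sq = torch.zeros(1, device=device)
    for component in components:
        for p in component.parameters():
            if p.grad is not None:
                local_sq += p.grad.detach().float().norm(2) ** 2

    # Sum the squared norms across devices: every rank ends up with the same
    # global gradient norm, hence the same clipping factor.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)

    global_norm = local_sq.sqrt().item()
    return min(1.0, max_norm / (global_norm + 1e-6))


# Usage sketch: after backward, before optimizer.step(), every rank applies
# the same factor to its local gradients:
#   factor = global_grad_clip_factor(my_components, max_norm=1.0, device='cuda')
#   for p in local_parameters:
#       if p.grad is not None:
#           p.grad.mul_(factor)
```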
i mostly have nitpickery of the docstring variety.
    src,
    decoder_input,
    src_lengths,
    rearrange(src, 't b 1 -> b t'),
in a perfect world we would just normalize the tensors to batch seq dim across the lib, but this world isn't perfect.
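For anyone skimming the diff, a tiny standalone illustration of what that rearrange does; the shapes are made up for the example.

```python
import torch
from einops import rearrange

# Toy tensor in the (time, batch, 1) layout used by the call above: t=4, b=3.
src = torch.arange(12).reshape(4, 3, 1)

# Drop the singleton last dim and put batch first: (t, b, 1) -> (b, t).
src_bt = rearrange(src, 't b 1 -> b t')

print(src.shape)     # torch.Size([4, 3, 1])
print(src_bt.shape)  # torch.Size([3, 4])
```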
Indeed. #80
mammoth/trainer.py (Outdated)
@@ -476,6 +493,9 @@ def _gradient_accumulation(
    try:
        if loss is not None:
            if torch.isnan(loss):
                raise Exception('Loss blowout')
can we/should we type this?
I think typed exceptions are mainly useful for recoverable problems that you may want to catch somewhere higher up. Blowing out of the training loop is not going to be recoverable.
A typed exception would, however, be useful to separate a NaN loss from OOMs and other exceptions raised during backward.
I express a strong disagreement.
- You do catch it (and everything else) on line 519.
- If you want a clean break, then make it a break / branching-out condition. If you want weird stack-jump behavior, make it clear to the other devs that you're doing a custom hack around a potential problem.
- I would also argue that typing helps other devs: e.g., if my dev run crashes because of a typo producing an out-of-bounds error, not being caught by this try/except is bound to make my stack trace less of a pain.

The minimal thing is to just make it a RuntimeError, so as to avoid cases like the one in my third point, but I wouldn't be against finding something more exotic / custom; that way you could except ExoticError to deal with NaN blowouts and except RuntimeError for CUDA OOMs.
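Something along those lines could look like the sketch below; `LossBlowoutError`, `backward_or_skip`, and `nan_batches` are hypothetical names for illustration, not the PR's code.

```python
import torch


class LossBlowoutError(RuntimeError):
    """Raised when the training loss turns NaN.

    Subclassing RuntimeError keeps broad `except RuntimeError` handlers working,
    while still allowing the more specific `except LossBlowoutError`.
    """


def backward_or_skip(loss, nan_batches: int) -> int:
    """Run backward unless the loss is NaN; return the updated skip count.

    Hypothetical helper: it only illustrates how a typed exception separates
    NaN blowouts from other errors raised during backward.
    """
    try:
        if loss is not None and torch.isnan(loss):
            raise LossBlowoutError('Loss blowout')
        loss.backward()
    except LossBlowoutError:
        # Recoverable: count the skipped batch and move on.
        nan_batches += 1
    # CUDA OOMs and everything else propagate to the caller.
    return nan_batches


# Usage sketch:
nan_batches = 0
loss = torch.tensor(float('nan'), requires_grad=True)
nan_batches = backward_or_skip(loss, nan_batches)  # -> 1, backward was skipped
```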
You are absolutely right.
I already forgot what I was trying to do here. The catch-everything-and-retry-forever was already there from before, and I tried to make it less bad, when I should have just removed it.
I'll remove the try/except, and the --max_nan_batches opt.
mammoth/trainer.py (Outdated)
@@ -496,10 +516,13 @@ def _gradient_accumulation(
        total_stats.update(batch_stats)
        report_stats.update(batch_stats)
        report_stats.update_task_loss(batch_stats.loss, metadata)

    except Exception:
I think this is also meant to catch CUDA OOMs, isn't it? In which case your NaN counter isn't valid?
Yeah, that's right. The current code maybe allows recovering from a single slightly too large batch by catching the OOM. This is quite unlikely, though, and wasn't intentional.
We could have a typed exception that increments the NaN counter, and let everything else through directly. That would cause OOMs to be instantly fatal.
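For contrast, a hedged sketch of what deliberately recovering from an oversized batch would look like, as opposed to letting OOMs be fatal. `try_batch` and `step_fn` are made-up names, and torch.cuda.OutOfMemoryError requires a reasonably recent PyTorch (older versions would have to match on RuntimeError).

```python
import torch


def try_batch(step_fn) -> bool:
    """Attempt one batch; return False if it OOMed and was skipped.

    `step_fn` is a hypothetical callable running forward/backward for the
    batch. Anything other than a CUDA OOM propagates and stops training.
    """
    try:
        step_fn()
        return True
    except torch.cuda.OutOfMemoryError:
        # Free cached blocks so the next (hopefully smaller) batch fits.
        torch.cuda.empty_cache()
        return False
```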
The old transformer encoder classes are still needed for the obsolete attention bridge. Removal is pending.
This makes the "all" language hack easier to use. When using this hack to achieve shared embeddings, it is no longer possible to use src_lang and tgt_lang to distinguish tasks from each other (because all tasks are just "all-all"). The iterate_tasks.py helper also supports using the task id as a template variable. Having the real src-tgt as the task id, without any useless train_ prefix, makes file naming less painful.
Force-pushed from 223bed4 to c4ff0c1
Parameters in the TransformerWrapper object, most notably to_logits, get their own distributed components and optimizers. Arguments of the TransformerWrapper can be set through the config file.