Miscellaneous fixes to the x-transformers implementation #79

Merged: 15 commits merged into main from feat/model_blowout on Nov 4, 2024

Conversation

@Waino (Collaborator) commented Oct 21, 2024

  • Validation no longer crashes (transposes were missing).
  • Added a distributed component covering the parameters of the TransformerWrapper object, most notably to_logits.
  • Arguments of TransformerWrapper can now be set through the config file.
  • Fixed the contents of state dicts, avoiding duplicate storage of some parameters.
  • Removed some obsolete opts.
  • Stats are now handled correctly both with and without accuracy computation: the type of the initial value is inferred from the preceding stats object (see the sketch after this list).
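
A minimal sketch of the last point, with the stats object modelled as a stand-in dataclass (the field names are assumptions, not the actual mammoth API):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Statistics:  # stand-in for the project's stats object (assumed shape)
        loss: float = 0.0
        n_words: int = 0
        n_correct: Optional[int] = None  # None means accuracy is not computed

    def make_initial_stats(prev: Statistics) -> Statistics:
        # Infer the type of the initial accuracy value from the preceding stats:
        # 0 if accuracy is being reported, None if it is not.
        return Statistics(n_correct=0 if prev.n_correct is not None else None)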

Waino added 9 commits October 7, 2024 12:15
Skip backward if loss is NaN.
Stop training if enough batches are skipped.
The default value must be either zero or None, depending on whether
accuracy is reported or not.
Parameters in the TransformerWrapper, e.g. to_logits, need their own
distributed component and optimizer.
The adapter injection code was causing parameter duplication.

Another issue: to normalize or not to normalize?
We compute a normalization based on either tokens or sents, but never
apply it. The effect can be compensated for using the learning rate, as
long as batches are approximately the same size. Too high learning rates
lead to gradient clipping, which is extra detrimental because each
component is individually clipped.

Clipping deterministically requires one of the following:
- access to gradients for all parameters of the entire model (infeasible)
- component local clipping (current approach)
- communicating a clipping factor across devices (maybe we should do
  this?)
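
A rough sketch of the third option, communicating a clipping factor across devices (the function and its use of torch.distributed here are illustrative, not the current component-local implementation):

    import torch
    import torch.distributed as dist

    def clip_grads_with_global_norm(local_params, max_norm, group=None):
        # Sum of squared gradient norms over the parameters held on this device.
        device = next(p.device for p in local_params if p.grad is not None)
        local_sq = torch.zeros(1, device=device)
        for p in local_params:
            if p.grad is not None:
                local_sq += p.grad.detach().float().norm(2) ** 2
        # Summing across devices yields the global gradient norm, so every
        # device derives the same clipping factor deterministically.
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=group)
        total_norm = local_sq.sqrt().item()
        clip_coef = min(1.0, max_norm / (total_norm + 1e-6))
        for p in local_params:
            if p.grad is not None:
                p.grad.mul_(clip_coef)
        return total_norm
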
@TimotheeMickus (Collaborator) left a comment

I mostly have nitpickery of the docstring variety.

Resolved review threads:
  • mammoth/model_builder.py
  • mammoth/distributed/components.py (4 threads, 1 outdated)
  • mammoth/opts.py (2 threads, both outdated)
src,
decoder_input,
src_lengths,
rearrange(src, 't b 1 -> b t'),
Collaborator:

in a perfect world we would just normalize the tensors to batch seq dim across the lib, but this world isn't perfect.

Collaborator (Author):

Indeed. #80
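
For reference, a minimal einops example of the layout conversion in the diff above (the shapes are illustrative):

    import torch
    from einops import rearrange

    src = torch.randint(0, 100, (7, 4, 1))   # (time, batch, 1): the legacy time-major layout
    src_bt = rearrange(src, 't b 1 -> b t')  # (batch, time): the batch-first layout x-transformers expects
    assert src_bt.shape == (4, 7)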

@@ -476,6 +493,9 @@ def _gradient_accumulation(

    try:
        if loss is not None:
            if torch.isnan(loss):
                raise Exception('Loss blowout')
Collaborator:

can we/should we type this?

@Waino (Collaborator, Author) commented Nov 4, 2024:

I think typed exceptions are mainly useful for recoverable problems that you may want to catch somewhere higher up. Blowing out of the training loop is not going to be recoverable.

The typed exception would be useful to separate NaN loss from OOM and other exceptions from backward.

Collaborator:

I express a strong disagreement.

  1. You do catch it (and everything else) on line 519.
  2. If you want a clean break, then make it a break / branching condition. If you want a weird stack-jump behavior, make it clear to the other devs that you're doing a custom hack around a potential problem.
  3. I would also argue that typing helps other devs: e.g., suppose my code crashes because of a typo producing an out-of-bounds error; not being caught by this try/except here would make my stack trace less of a pain.

The minimal thing is to just make it a RuntimeError so as to avoid cases as in 3, but I wouldn't be against finding something more exotic / custom; that way you could except ExoticError to deal with NaN blowouts and except RuntimeError for CUDA OOMs.
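
A minimal sketch of that suggestion (the exception name and counter handling are illustrative; the PR ultimately removes the try/except instead):

    import torch

    class LossBlowoutError(Exception):
        """NaN loss, kept separate from CUDA OOMs (which surface as RuntimeError)."""

    def backward_or_skip(loss, nan_counter, max_nan_batches=10):
        # Sketch: count NaN losses and skip backward for them; let everything
        # else (e.g. CUDA OOM) propagate and remain fatal.
        try:
            if loss is not None:
                if torch.isnan(loss):
                    raise LossBlowoutError('Loss blowout')
                loss.backward()
        except LossBlowoutError:
            nan_counter += 1
            if nan_counter > max_nan_batches:
                raise  # enough batches skipped: stop training
        return nan_counter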

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are absolutely right.

I already forgot what I was trying to do here. The catch-everything-and-retry-forever was already there from before, and I tried to make it less bad, when I should have just removed it.

I'll remove the try/except, and the --max_nan_batches opt.

@@ -496,10 +516,13 @@ def _gradient_accumulation(
        total_stats.update(batch_stats)
        report_stats.update(batch_stats)
        report_stats.update_task_loss(batch_stats.loss, metadata)

    except Exception:
Collaborator:

I think this is also meant to catch CUDA OOMs, isn't it? In which case your NaN counter isn't valid?

Collaborator (Author):

Yeah, that's right. The current code may allow recovering from a single slightly too large batch by catching the OOM. This is quite unlikely, though, and wasn't intentional.

We could have a typed exception that increments the NaN counter, and let everything else propagate directly. That would make OOMs instantly fatal.

Waino added 6 commits November 4, 2024 10:25
The old transformer encoder classes are still needed for the obsolete
attention bridge. Removal is pending.
This makes the "all" language hack easier to use.
When using this hack to achieve shared embeddings, it is no longer
possible to use src_lang and tgt_lang to distinguish tasks from each
other (because all tasks are just "all-all").
The iterate_tasks.py helper also supports using the task id as a
template variable. Having the real src-tgt pair as the task id, without a
useless train_ prefix, makes file naming less painful.
@Waino force-pushed the feat/model_blowout branch from 223bed4 to c4ff0c1 on November 4, 2024 12:03
@Waino merged commit dd4e1ff into main on Nov 4, 2024
2 checks passed
@Waino deleted the feat/model_blowout branch on November 4, 2024 13:43
@Waino mentioned this pull request on Nov 4, 2024