This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Decoder-Only Transformer #4329

Merged
merged 29 commits into from
May 4, 2022

Conversation

@spencerp (Contributor) commented Jan 27, 2022

Supporting decoder-only transformer training:

parlai train_model --model transformer/decoder

Summary of changes

  • New TransformerDecoderLayer without encoder attention
  • New TransformerDecoder that concatenates input and "encoder state"
    • Needed to pass the dictionary to TransformerDecoder to properly do concatenation
    • Right-pad the label, left-pad the context
  • New PassThroughEncoder, similar to IdentityLayer but compatible with TransformerEncoder API
    • There are many assumptions in TorchGeneratorAgent about the presence of an "encoder", so the path of most code reuse was to use a dummy encoder that satisfies those assumptions
  • New DecoderAgent to override build_model
  • Extract label from scores to calculate loss (done at end of TransformerDecoderOnly.forward)
  • Only use last query_len rows from incremental attention to account for incremental step from context to first generated token
  • Added DecoderIncrState and DecoderLayerIncrState type aliases
  • Added unit tests
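The padding scheme described above (left-pad the context, right-pad the label, then concatenate) can be sketched in isolation. This is my own illustrative code, not ParlAI's; the pad id 0 and the helper name are assumptions:

```python
PAD = 0  # assumed pad token id, for illustration only

def concat_batch(contexts, labels):
    """Left-pad contexts, right-pad labels, then concatenate per row.

    Left-padding the context keeps pad tokens from landing between the
    context and the label, so every row's label starts right after its
    context.
    """
    ctx_len = max(len(c) for c in contexts)
    lab_len = max(len(l) for l in labels)
    batch = []
    for ctx, lab in zip(contexts, labels):
        row = [PAD] * (ctx_len - len(ctx)) + ctx   # left-pad context
        row += lab + [PAD] * (lab_len - len(lab))  # right-pad label
        batch.append(row)
    return batch

batch = concat_batch([[5, 6, 7], [8]], [[1, 2], [3]])
# each row: padded context (len 3) followed by padded label (len 2)
```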

Summary of structural changes in decoder.py
(diagram: decoder_only_refac)
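The incremental-attention bullet above (using only the last query_len rows on the step from context to first generated token) can be illustrated with a toy key cache. The names here are hypothetical, not from decoder.py:

```python
def incremental_step(cache, new_keys, query_len):
    """Append new keys to the incremental cache and keep only the last
    `query_len` rows of output.

    On the context -> first-token step, new_keys covers the full context
    plus one generated token, but only query_len rows are actually being
    queried, so the earlier rows are dropped.
    """
    cache = cache + new_keys
    return cache, new_keys[-query_len:]

# first incremental step: context of 3 tokens plus 1 newly generated token
cache, out = incremental_step([], ['k0', 'k1', 'k2', 'k3'], 1)
# out keeps only the row for the newly generated token
```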

Testing steps

python -m pytest tests/test_transformers.py

(screenshot: test results, 2022-02-08 12:37)

python -m pytest tests/test_light_whoami.py

(screenshot: test results, 2022-02-08 12:12)

Also ran some small training runs locally to sanity check:

TF_ARGS="--embedding-size 128 --ffn-size 512 --batchsize 16 --eval-batchsize 16 --model-parallel true --variant prelayernorm --n-heads 8 --n-positions 512 --activation gelu --text-truncate 256 --label-truncate 128 -lr 7e-05 --lr-scheduler invsqrt --optimizer adam --warmup_updates 1000 -vp 10 -vmt ppl -vmm min --load-from-checkpoint false -vstep 500 --validation-max-exs 1000 -tstep 150000 --log-every-n-secs 30 --update-freq 1 --dynamic-batching full"
parlai train_model -t convai2 --model transformer/decoder --model-file /tmp/test_decoder_only_model --n-layers 2 $TF_ARGS
parlai train_model -t convai2 --model transformer/generator --model-file /tmp/test_enc_decoder_model --n-encoder-layers 1 --n-decoder-layers 1 $TF_ARGS

Benchmark Comparison with Encoder-Decoder
Both trained to 10k steps with 120M parameters (8 layers) on 2x Quadro GP100 GPUs. The encoder/decoder is closer to 140M parameters due to the cross-attention.
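A back-of-the-envelope check of that parameter gap (my own arithmetic, not from the PR): cross-attention adds four d×d projections to each of the 4 decoder layers, which accounts for most of the ~20M difference once biases, layer norms, and embeddings are set aside.

```python
d = 1024                  # --embedding-size
ffn = 4096                # --ffn-size
attn = 4 * d * d          # q, k, v, and output projections
ffn_params = 2 * d * ffn  # two feed-forward projections
layer = attn + ffn_params # per layer, ignoring biases and layer norms

decoder_only = 8 * layer            # 8 self-attention-only layers
enc_dec = 8 * layer + 4 * attn      # plus cross-attention in 4 decoder layers
extra = enc_dec - decoder_only      # about 16.8M extra parameters
```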

TF_ARGS="--embedding-size 1024 --ffn-size 4096 --batchsize 16 --eval-batchsize 16 --model-parallel true --variant prelayernorm --n-heads 16 --n-positions 512 --activation gelu --text-truncate 256 --label-truncate 128 -lr 7e-05 --lr-scheduler invsqrt --optimizer adam --warmup_updates 1000 -vp 10 -vmt ppl -vmm min --load-from-checkpoint false -vstep 1000 --validation-max-exs 1000 -tstep 10000 --log-every-n-secs 30 --update-freq 1 --dynamic-batching full --fp16 true"
parlai train_model -t internal:new_reddit --model transformer/decoder --model-file /tmp/test_decoder_only_model --n-layers 8 $TF_ARGS
...
13:17:37 | time:6960s total_exs:433912 total_steps:9950 epochs:0.00 time_left:35s
    clen  clip  ctpb  ctps  ctrunc  ctrunclen  exps  exs  fp16_loss_scalar  gnorm  gpu_mem  llen  loss        lr  ltpb  ltps  ltrunc  ltrunclen   ppl  token_acc  token_em  total_train_updates  tpb  tps   ups
   58.72     1  2603  4957  .01423      .8327 85.63 2248            131072  1.711    .2126 26.09 4.036 2.219e-05  1173  2234       0          0 56.59      .2843         0                 9950 3776 7191 1.905

13:18:04 | Stopping from Maximum LR steps
parlai train_model -t internal:new_reddit --model transformer/generator --model-file /tmp/test_enc_decoder_model --n-encoder-layers 4 --n-decoder-layers 4 $TF_ARGS
...
10:52:31 | time:5367s total_exs:431792 total_steps:9950 epochs:0.00 time_left:27s
    clen  clip  ctpb  ctps  ctrunc  ctrunclen  exps  exs  fp16_loss_scalar  gnorm  gpu_mem  llen  loss        lr  ltpb  ltps  ltrunc  ltrunclen   ppl  token_acc  token_em  total_train_updates  tpb   tps  ups
   64.57     1  2711  6939  .02041      1.706 110.4 2156            131072  1.633    .1838 28.03 4.066 2.219e-05  1208  3094       0          0 58.33      .2803         0                 9950 3919 10033 2.56

10:52:51 | Stopping from Maximum LR steps

@stephenroller stephenroller self-requested a review March 1, 2022 16:46
@stephenroller (Contributor) left a comment


Yeah, looks pretty good to me. Do you mind giving us a benchmark on -t internal:new_reddit with ~125M params, trained to, say, 10k steps?

(three resolved review comments on parlai/agents/transformer/decoder_only.py)
Override of ``TorchAgent.build_model``.
"""
assert (
self.opt['n_encoder_layers'] == -1

wishlist: this would be a really great opportunity to finally implement the parser.remove_arg functionality we've desired for a long time
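For context, a hypothetical remove_arg could look like the sketch below, written against plain argparse (ParlAI's parser is an argparse subclass). The function name and approach are my own guess at the wished-for feature, not an existing API:

```python
import argparse

def remove_arg(parser, option_string):
    """Drop a previously registered argument so a subclass can retire it.

    Uses argparse internals (_actions, _remove_action,
    _option_string_actions), so this is a sketch, not a supported API.
    """
    for action in list(parser._actions):
        if option_string in action.option_strings:
            parser._remove_action(action)
            for opt in action.option_strings:
                parser._option_string_actions.pop(opt, None)

parser = argparse.ArgumentParser()
parser.add_argument('--n-encoder-layers', type=int, default=-1)
remove_arg(parser, '--n-encoder-layers')
args = parser.parse_args([])  # --n-encoder-layers is no longer recognized
```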

(two resolved review comments on parlai/agents/transformer/modules/decoder.py)
@@ -1393,3 +1393,10 @@ def error(self, message):
self.print_help()
_sys.stderr.write('\nParse Error: %s\n' % message)
_sys.exit(2)


def default(val, default):
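The body of this helper is cut off in the diff excerpt. A plausible reading of a helper with this signature (my guess, not the actual ParlAI code) is a None-aware fallback:

```python
def default(val, default):
    """Return val unless it is None, in which case fall back to default.

    Unlike `val or default`, falsy-but-valid values such as 0 or '' are kept.
    """
    return val if val is not None else default

default(None, 7)  # falls back to 7
default(0, 7)     # 0 is kept; only None triggers the fallback
```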

(y)

@stephenroller (Contributor) commented

Alright it's a little slower but good enough for today. Let Jason know it's done!

@spencerp (Contributor, Author) commented Mar 5, 2022

Hm, there's a distillation test failure that I don't understand yet: dec_hid_loss and enc_dec_attn_loss don't match what the test expects, which is suspicious, so I'm going to hold off on merging until I understand it better.

@klshuster klshuster mentioned this pull request Mar 25, 2022
@klshuster (Contributor) commented

cc @EricMichaelSmith: two distillation tests are failing due to this PR. Do you have any idea why that might be? The differences are very small but consistent; these tests do not seem to fail on main but consistently fail on this branch.

@stephenroller (Contributor) commented

Let's just let this sit a while. The regression failure is a little concerning. It's enough to make me worry the code has changed in a subtle way, but close enough that it's tempting to just --force-regen and forget about it.

@EricMichaelSmith (Contributor) commented

> Let's just let this sit a while. The regression failure is a little concerning. It's enough to make me worry the code has changed in a subtle way, but close enough that it's tempting to just --force-regen and forget about it.

Hmm, seconded. Yeah, @klshuster, I'm not sure offhand, but perhaps this does indicate that the forward pass has indeed changed in some subtle way?

@github-actions (bot) commented

This PR has not had activity in 30 days. Closing due to staleness.

@github-actions github-actions bot added the stale label Apr 29, 2022
@spencerp (Contributor, Author) commented
I will resuscitate it next week :)

@github-actions github-actions bot removed the stale label Apr 30, 2022
@spencerp (Contributor, Author) commented May 3, 2022

Finally got back to this today and figured out the issue!

Apparently the distillation code is sensitive to the order in which modules are initialized 😔. I made a branch to demonstrate the minimal change needed to repro the test break: #4526

To my knowledge, module initialization order has no implications for training dynamics or model performance. So I'm inclined to call this a distillation bug that is outside the scope of this PR. Does anyone know differently? @EricMichaelSmith

If I don't hear any objections by Thursday I'll merge as-is with the broken test and file an issue to fix the bug in the distillation code.

@spencerp (Contributor, Author) commented May 3, 2022

Chatted with @EricMichaelSmith offline. He pointed out that this could be the result of a different order of random operations during module initialization. I was able to confirm that's the problem; details in #4526
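A minimal demonstration of that root cause, in pure Python rather than PyTorch (my own toy stand-in for module init): with a fixed seed, the same two "modules" end up with different weights when initialized in a different order, because each one consumes a different slice of the RNG stream.

```python
import random

def init_two(order):
    """Initialize two named 'modules' from a fixed seed in the given order."""
    random.seed(0)
    sizes = {'a': 2, 'b': 3}
    weights = {}
    for name in order:
        weights[name] = [random.random() for _ in range(sizes[name])]
    return weights

w1 = init_two(['a', 'b'])
w2 = init_two(['b', 'a'])
# w1['a'] != w2['a'], even though the seed and the modules are identical:
# only the order in which random numbers were drawn changed
```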

So I'm just going to update the test numbers and merge.

@spencerp spencerp merged commit ecdfbd0 into main May 4, 2022
@spencerp spencerp deleted the decoder-only branch May 4, 2022 00:27