Conversation
Yeah, looks pretty good to me. Do you mind giving us a benchmark on `-t internal:new_reddit` with 125M params, trained to, say, 10k steps?
```
    Override of ``TorchAgent.build_model``.
    """
    assert (
        self.opt['n_encoder_layers'] == -1
```
Wishlist: this would be a really great opportunity to finally implement the `parser.remove_arg` functionality we've desired for a long time.
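For reference, here is a rough sketch of what such a `parser.remove_arg` helper could look like on an argparse-based parser. This is hypothetical code, not part of this PR, and it leans on argparse internals (`_actions`, `_option_string_actions`, `_remove_action`):

```python
import argparse


def remove_arg(parser: argparse.ArgumentParser, option: str) -> None:
    """Hypothetical helper: unregister a previously added CLI argument.

    `option` is any of the argument's option strings, e.g. '--n-encoder-layers'.
    """
    for action in list(parser._actions):
        if option in action.option_strings:
            # Drop every alias of the option so it can no longer be parsed.
            for opt_str in action.option_strings:
                parser._option_string_actions.pop(opt_str, None)
            # Remove the action itself so it disappears from --help output.
            parser._remove_action(action)
            return
    raise ValueError(f'{option} is not a registered argument')
```

A decoder-only agent could then call something like `remove_arg(parser, '--n-encoder-layers')` instead of asserting that the option was left at its default.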
```
@@ -1393,3 +1393,10 @@ def error(self, message):
        self.print_help()
        _sys.stderr.write('\nParse Error: %s\n' % message)
        _sys.exit(2)


def default(val, default):
```
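The body of the new helper is not shown in this excerpt, but judging by its name it reads like a small None-coalescing utility; a plausible sketch (an assumption, not a quote of the actual implementation):

```python
def default(val, default):
    """Return `val` if it is not None, otherwise fall back to `default`."""
    return val if val is not None else default


# Example: resolve a possibly-unset option against a generic fallback.
opt = {'n_layers': 8, 'n_decoder_layers': None}
n_decoder_layers = default(opt['n_decoder_layers'], opt['n_layers'])  # -> 8
```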
(y)
Alright, it's a little slower, but good enough for today. Let Jason know it's done!
Hm, there's a distillation test failing which I don't understand yet...
cc @EricMichaelSmith: there are two distillation tests failing due to this PR; do you have any idea why that might be the case? It looks like very small (but consistent) differences in testing (that is, these tests do not seem to be failing on main but consistently fail on this branch).
Let's just let this sit a while. The regression failure is a little concerning. It's enough to make me worry the code has changed in a subtle way, but close enough that it's tempting to just `--force-regen` and forget about it.
Hmm, seconded. Yeah, @klshuster, I'm not sure offhand, but perhaps this does indicate that the forward pass has indeed changed in some subtle way?
This PR has not had activity in 30 days. Closing due to staleness.
I will resuscitate it next week :)
Finally got back to this today and figured out the issue! Apparently the distillation code is sensitive to the order in which modules are initialized 😔. I made a branch to demonstrate the minimal change needed to repro the test break: #4526. To my knowledge, module initialization order has no implications for training dynamics/model performance, so I'm inclined to call this a distillation bug and outside the scope of this PR. Anyone know differently? @EricMichaelSmith, if I don't hear any objections by Thursday, I'll merge as-is with the broken test and file an issue to fix the bug in the distillation code.
Chatted with @EricMichaelSmith offline. He pointed out that this could be a result of a different order of random operations done during module initialization. I was able to confirm that's the problem; details in #4526. So I'm just going to update the test numbers and merge.
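For anyone curious how initialization order alone can shift test numbers: module constructors draw from the global RNG, so building the same modules in a different order hands each module a different slice of the random stream, even under the same seed. A minimal PyTorch illustration (not the PR's actual models):

```python
import torch
from torch import nn


def build(order_swapped: bool) -> nn.ModuleDict:
    # Same seed and same modules; only the construction order differs.
    torch.manual_seed(0)
    if order_swapped:
        decoder = nn.Linear(16, 16)
        encoder = nn.Linear(16, 16)
    else:
        encoder = nn.Linear(16, 16)
        decoder = nn.Linear(16, 16)
    return nn.ModuleDict({'encoder': encoder, 'decoder': decoder})


a, b = build(order_swapped=False), build(order_swapped=True)
# Each encoder consumed a different slice of the RNG stream, so its weights differ.
print(torch.equal(a['encoder'].weight, b['encoder'].weight))  # False
```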
Supporting decoder-only transformer training:

Summary of changes
- `TransformerDecoderLayer` without encoder attention
- `TransformerDecoder` that concatenates input and "encoder state" (see the sketch after this list)
  - Needed to pass the dictionary to `TransformerDecoder` to properly do concatenation
- `PassThroughEncoder`, similar to `IdentityLayer` but compatible with the TransformerEncoder API
  - `TorchGeneratorAgent` assumes an encoder exists, so the path of most code reuse led to using a dummy encoder to satisfy those assumptions
- `DecoderAgent` to override `build_model`
- Drops `query_len` rows from incremental attention (`TransformerDecoderOnly.forward`) to account for the incremental step from the context to the first generated token
- `DecoderIncrState` and `DecoderLayerIncrState` type aliases
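A minimal sketch of the dummy-encoder-plus-concatenation idea above. This is illustrative only: `PassThroughEncoder` is named after the new module, but `ToyDecoderOnly` and its shapes are assumptions, not the PR's actual code:

```python
import torch
from torch import nn


class PassThroughEncoder(nn.Module):
    """Dummy 'encoder' that returns the context tokens unchanged, so agent code
    that expects an (encoder_states, ...) contract keeps working."""

    def forward(self, input_tokens: torch.LongTensor) -> torch.LongTensor:
        return input_tokens


class ToyDecoderOnly(nn.Module):
    """Prepends the passed-through context to the decoder input and runs the
    layer stack over the concatenated sequence."""

    def __init__(self, embedding: nn.Embedding, layers: nn.ModuleList):
        super().__init__()
        self.embedding = embedding
        self.layers = layers

    def forward(self, decoder_input: torch.LongTensor, encoder_state: torch.LongTensor) -> torch.Tensor:
        # "encoder state" here is just the context token IDs from PassThroughEncoder.
        full_sequence = torch.cat([encoder_state, decoder_input], dim=1)
        tensor = self.embedding(full_sequence)
        for layer in self.layers:  # causal self-attention layers in the real model
            tensor = layer(tensor)
        # Only positions belonging to decoder_input produce predictions.
        return tensor[:, encoder_state.size(1):, :]


# Usage with placeholder layers, just to show the shapes:
model = ToyDecoderOnly(nn.Embedding(100, 32), nn.ModuleList([nn.Identity(), nn.Identity()]))
context = torch.randint(0, 100, (1, 5))
label = torch.randint(0, 100, (1, 3))
out = model(label, PassThroughEncoder()(context))  # shape: (1, 3, 32)
```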

Summary of structural changes in `decoder.py`
Testing steps
Also ran some small training runs locally to sanity check:
Benchmark Comparison with Encoder-Decoder
Both trained to 10k steps with 120M parameters (8 layers) on 2x Quadro GP100 GPUs. The encoder/decoder is closer to 140M parameters due to the cross-attention.
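Back-of-the-envelope on why cross-attention accounts for roughly the 20M-parameter gap. The model width is not stated above, so `d_model = 768` below is purely an assumption that happens to land near that gap:

```python
# Each decoder layer's cross-attention block adds Q, K, V, and output
# projections of size d_model x d_model (ignoring biases and layer norms).
d_model, n_layers = 768, 8  # assumed width; layer count from the benchmark setup
extra_per_layer = 4 * d_model * d_model
total_extra = n_layers * extra_per_layer
print(f'extra parameters from cross-attention: {total_extra / 1e6:.1f}M')  # ~18.9M
```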