Transformer divisibility error validation #3105
Conversation
ludwig/config_validation/checks.py
Outdated
sequence_types = [SEQUENCE, TEXT, TIMESERIES]
for sequence_type in sequence_types:
    encoder = getattr(config.defaults, sequence_type).encoder
    if encoder.type == "transformer" and not is_divisible(encoder.hidden_size, encoder.num_heads):
        raise ConfigValidationError(
            f"Default {sequence_type} transformer encoder requires encoder.hidden_size to be divisible by "
            f"encoder.num_heads. Found hidden_size {encoder.hidden_size} and num_heads {encoder.num_heads}."
        )
I think we might be able to get away without checking config.defaults, as all of these auxiliary checks are run after defaults have been resolved into feature configs. WDYT?
Oh, cool! I'll remove the defaults block.
}

with pytest.raises(ConfigValidationError):
    validate_config(config)
Use ModelConfig.from_dict(config) instead of validate_config, which I'm planning on removing in #3104.
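A sketch of the suggested pattern, using a hypothetical stand-in so it runs in isolation (the real ModelConfig and ConfigValidationError live in Ludwig; everything below is an assumption about their behavior):

```python
class ConfigValidationError(ValueError):
    """Stand-in for Ludwig's ConfigValidationError (hypothetical)."""

class ModelConfig:
    """Stand-in whose from_dict validates on construction, per the suggestion."""

    @staticmethod
    def from_dict(config: dict) -> dict:
        encoder = config.get("encoder", {})
        if encoder.get("type") == "transformer" and encoder["hidden_size"] % encoder["num_heads"] != 0:
            raise ConfigValidationError(
                f"hidden_size {encoder['hidden_size']} is not divisible by "
                f"num_heads {encoder['num_heads']}"
            )
        return config

# A real test would use `with pytest.raises(ConfigValidationError): ...`;
# plain try/except keeps this sketch dependency-free.
try:
    ModelConfig.from_dict({"encoder": {"type": "transformer", "hidden_size": 256, "num_heads": 7}})
    raised = False
except ConfigValidationError:
    raised = True
assert raised  # the invalid config is rejected at construction time
```

The design point of the suggestion: once validation happens inside config construction, tests exercise the same code path users hit, rather than a separate validate_config entry point.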
So all of the tests fail when I use ModelConfig.from_dict. Do we need to wait for the PR to land before updating?
@jeffkinnison you're totally right! Let me try to land that PR first.
Force-pushed from 81d0be2 to e325b86
Followup to #3066 that adds a validation check for `encoder.hidden_size % encoder.num_heads == 0` for transformer encoders.
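For context on why the check exists (general transformer background, not stated in this PR): multi-head attention splits the hidden dimension into num_heads equal slices, which is only possible when the division is exact.

```python
# Each attention head works on a slice of size hidden_size // num_heads,
# so hidden_size must be an exact multiple of num_heads.
hidden_size, num_heads = 256, 8
head_dim = hidden_size // num_heads
assert head_dim * num_heads == hidden_size  # 32 * 8 == 256: valid

# An invalid pairing like (256, 7) loses dimensions to integer division:
assert (256 // 7) * 7 != 256
```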