Refactored ModelConfig object into a Marshmallow schema #2906
for more information, see https://pre-commit.ci
…into refactor-schema
This is great, definitely a lot of terrific changes in here. Thanks for putting this together, Travis. I just had a few nits I noticed while going through this, but at a high level the logic all makes sense and doesn't have any major errors. The one thing I'll call out is that we may need to double-check that no changes are needed on the Predibase side to reconcile these changes. It looks like the API has been preserved, so I don't expect anything here, but I do remember running into issues the last time we made changes like these, so I just wanted to mention it. Overall, LGTM!
"description": "Type of the input feature",
},
"column": {"type": "string", "title": "column", "description": "Name of the column."},
"type": "object",
Since we're no longer treating this as an array with `"minItems": 1`, will this still validate that the user provided at least one feature?
Good question, it should get it from the `FeaturesTypeSelection`, which has `min_length=1` set by default. But let me check.
Yes, looks like it works. Added tests to verify.
Awesome that it works, also verified it locally. I'm a little confused about how this works though... the custom schema here is no longer even an array, so how does it end up being one in the final assembly (let alone how does the `allOf` end up in the right place)?
I think it's `marshmallow-jsonschema` magic.
@tgaddair I'm not so sure. Whatever is set in `_json_type_mapping` should override any kind of automatic logic, so at minimum, while I might expect it to perhaps fix the `type` to be an `array`, I would not expect it to insert `items` or know to cherry-pick `allOf` to one level higher than it is in the final schema. But I actually wanted to talk about this in our next sync.
I just figured this out. Despite the fact that `json_type_mapping` is supposed to be the ultimate override method, `marshmallow_jsonschema` will change the type of the field if the class in question derives from a native marshmallow type. Since `FeaturesTypeSelection` ultimately derives from `marshmallow.fields.List`, the package will force its JSON schema into proper list/array format, in particular by taking the output of `json_type_mapping` and shoving it inside of a standard list's JSON schema (that is where the surrounding `type: array : { items: { ...` comes from).
We are lucky that all of the feature's schema attributes and conditional logic are indeed supposed to be nested inside of `items`, as otherwise this would have caused a major error!
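For anyone following along, here is a minimal sketch of the wrapping described above. It is an illustration only and not `marshmallow_jsonschema`'s actual code: `wrap_list_schema` and `feature_item_schema` are hypothetical names, and mapping `min_length` to `minItems` is an assumption about how the at-least-one-feature check could survive the change.

```python
from typing import Any, Dict, Optional


def wrap_list_schema(item_schema: Dict[str, Any], min_length: Optional[int] = None) -> Dict[str, Any]:
    """Nest a custom per-item schema inside a standard JSON-schema array."""
    wrapped: Dict[str, Any] = {"type": "array", "items": item_schema}
    if min_length is not None:
        # Assumption: the list's minimum length is emitted as "minItems".
        wrapped["minItems"] = min_length
    return wrapped


# Hypothetical custom schema for a single feature: an object, not an array.
feature_item_schema = {
    "type": "object",
    "properties": {
        "type": {"type": "string", "description": "Type of the input feature"},
        "column": {"type": "string", "title": "column", "description": "Name of the column."},
    },
    "allOf": [],  # per-type conditional logic would be listed here
}

final_schema = wrap_list_schema(feature_item_schema, min_length=1)
```

In this sketch the `allOf` block ends up under `items` simply because the whole custom schema is treated as the item schema, which matches the observation that the conditionals were meant to live there anyway.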
Very interesting! Nice discovery.
ludwig/schema/combiners/base.py
Outdated
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils


@DeveloperAPI
@dataclass(repr=False, order=True)
Can we import and use ludwig_dataclass here instead?
from ludwig.schema.utils import ludwig_dataclass
@ludwig_dataclass
Done. What about `BaseFeatureConfig` and the other classes in that file? Does it matter that `order=True` isn't specified on those?
Ya, we should use `@ludwig_dataclass` on those as well. Currently it's not a huge deal, since a lot of the parameters on those classes are filtered out or not shown for one reason or another. However, we will need it eventually, and having `order=True` doesn't hurt, so I think adding it in is a great idea.
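As a rough illustration of why a shared decorator helps here, the sketch below wraps the standard-library `dataclass` with the same `repr=False, order=True` options shown in the diff. `ludwig_dataclass_sketch` and `ExampleCombinerConfig` are hypothetical stand-ins; the real `ludwig_dataclass` in `ludwig.schema.utils` may be built differently.

```python
from dataclasses import dataclass


def ludwig_dataclass_sketch(cls):
    """Apply the shared dataclass options in one place (hypothetical stand-in)."""
    return dataclass(repr=False, order=True)(cls)


@ludwig_dataclass_sketch
class ExampleCombinerConfig:
    # Hypothetical fields, used only to exercise the decorator.
    num_layers: int = 1
    dropout: float = 0.0


small = ExampleCombinerConfig(num_layers=1)
large = ExampleCombinerConfig(num_layers=2)

# order=True generates __lt__ and friends, comparing fields in declaration order.
is_ordered = small < large
```

With a single decorator, every config class picks up the same options, so nobody has to remember to pass `order=True` by hand.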
Super, super clean. A few nits... and questions 🤔
from ludwig.constants import MODEL_ECD, MODEL_TYPE, PREPROCESSING, SPLIT
from ludwig.schema import utils as schema_utils

# TODO(travis): figure out why we need these imports to avoid circular import error
Is this actually the case even with validation moved to a separate module?
Yeah, unfortunately so. I think the right fix here is to remove the circular import structure so schema has no deps on the encoders, etc. But this is a larger refactor.
error = e

if error is not None:
    raise ValidationError(f"Failed to validate JSON schema for config. Error: {error.message}")
nice
class BinaryInputFeatureConfig(BaseInputFeatureConfig, BinaryInputFeatureConfigMixin):
    """BinaryInputFeatureConfig is a dataclass that configures the parameters used for a binary input feature."""

    encoder: BaseEncoderConfig = None
I'm assuming that without setting `= None` here you get a bunch of dataclass parameter order initialization errors?
Right, I think it has to do with the "can't have a default param before a required param" stuff.
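A small standard-library sketch of the ordering rule being discussed, assuming the base class already declares a field with a default. `BaseFeatureSketch`, `BrokenFeatureSketch`, and `WorkingFeatureSketch` are hypothetical names, not the Ludwig classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseFeatureSketch:
    # Inherited field that already has a default value.
    column: Optional[str] = None


try:
    @dataclass
    class BrokenFeatureSketch(BaseFeatureSketch):
        # No default here, but it follows an inherited field that has one.
        encoder: dict

    error_message = None
except TypeError as exc:
    # dataclass refuses to generate an __init__ with this parameter order.
    error_message = str(exc)


@dataclass
class WorkingFeatureSketch(BaseFeatureSketch):
    # Giving the field a default (here None) satisfies the ordering rule.
    encoder: Optional[dict] = None


feature = WorkingFeatureSketch(encoder={"type": "passthrough"})
```

The `= None` default is what lets the subclass field sit after the inherited defaults; the real value is then expected to be filled in later.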
@@ -109,12 +105,14 @@ class CommonSchedulerOptions:
),
)

max_t: int = max_t_alias()
`max_t` is not on every scheduler, e.g.: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#median-stopping-rule-tune-schedulers-medianstoppingrule
+1
The reason I had to move it here is that we run some checks against it for hyperopt defaults, but I can rework it.
Fixed.
f"Invalid params for optimizer: {value}, expect dict with at least a valid `type` attribute."
f"Invalid optimizer type: '{opt_type}', expected one of: {list(optimizer_registry.keys())}."
Much more useful error -> resolution messages, thanks for adding these in
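To show how messages like these might be wired up, here is a hedged sketch of a validation helper. The registry contents and `check_optimizer_params` are hypothetical, and using the two messages for two separate failure modes is an assumption; the diff alone does not show how the real code branches.

```python
from typing import Any, Dict

# Hypothetical registry; the real one maps optimizer names to config classes.
optimizer_registry: Dict[str, Any] = {"adam": object, "sgd": object}


def check_optimizer_params(value: Any) -> str:
    """Return the optimizer type, or raise with a message that names the fix."""
    if not isinstance(value, dict) or "type" not in value:
        raise ValueError(
            f"Invalid params for optimizer: {value}, expect dict with at least a valid `type` attribute."
        )
    opt_type = value["type"]
    if opt_type not in optimizer_registry:
        # Listing the valid keys tells the user exactly what to change.
        raise ValueError(
            f"Invalid optimizer type: '{opt_type}', expected one of: {list(optimizer_registry.keys())}."
        )
    return opt_type
```

Calling `check_optimizer_params({"type": "adamw"})` with this toy registry raises an error that lists the registered names, so the fix is visible in the message itself.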
else:
    return get_default_decoder_type(feature[TYPE])


def has_trainable_encoder(config: ModelConfig) -> bool:
nit: would it make sense to also move `has_trainable_encoder` and `has_pretrained_encoder` into `ludwig/schema/model_types/utils.py`?
Yes, we can do it in a follow-up.
from typing import Any, Dict, Optional

import pytest

from ludwig.utils.config_utils import merge_fixed_preprocessing_params
Should we just delete this test module? I'm assuming you've commented out these tests because these functions are only used within the ModelConfig object, and they seem to work?
I want to revisit in a follow-up to make it work with the new structure.
This cleans up all the schema implementations we've had by a lot, thanks for putting this together 👏
Only minor nits, but looks good otherwise!