Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Annotations] Annotate Schema Part 2: decoders, encoders, defaults, combiners, loss and preprocessing #2799

Merged
merged 3 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/common_transformer_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False, order=True)
class CommonTransformerConfig:
"""Common transformer parameter values."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False, order=True)
class ComparatorCombinerConfig(BaseCombinerConfig):
"""Parameters for comparator combiner."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(order=True, repr=False)
class ConcatCombinerConfig(BaseCombinerConfig):
"""Parameters for concat combiner."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/project_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(order=True, repr=False)
class ProjectAggregateCombinerConfig(BaseCombinerConfig):

Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import SEQUENCE
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
Expand All @@ -10,6 +11,7 @@
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False, order=True)
class SequenceCombinerConfig(BaseCombinerConfig):
"""Parameters for sequence combiner."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/sequence_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False)
class SequenceConcatCombinerConfig(BaseCombinerConfig):
"""Parameters for sequence concat combiner."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/tab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.common_transformer_options import CommonTransformerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False, order=True)
class TabTransformerCombinerConfig(BaseCombinerConfig, CommonTransformerConfig):
"""Parameters for tab transformer combiner."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False, order=True)
class TabNetCombinerConfig(BaseCombinerConfig):
"""Parameters for tabnet combiner."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/combiners/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.common_transformer_options import CommonTransformerConfig
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA


@DeveloperAPI
@dataclass(repr=False, order=True)
class TransformerCombinerConfig(BaseCombinerConfig, CommonTransformerConfig):
"""Parameters for transformer combiner."""
Expand Down
4 changes: 4 additions & 0 deletions ludwig/schema/combiners/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import TYPE
from ludwig.schema import utils as schema_utils
from ludwig.schema.metadata.combiner_metadata import COMBINER_METADATA
Expand All @@ -7,6 +8,7 @@
combiner_registry = Registry()


@DeveloperAPI
def register_combiner(name: str):
def wrap(cls):
combiner_registry[name] = cls
Expand All @@ -15,6 +17,7 @@ def wrap(cls):
return wrap


@DeveloperAPI
def get_combiner_jsonschema():
"""Returns a JSON schema structured to only require a `type` key and then conditionally apply a corresponding
combiner's field constraints."""
Expand All @@ -37,6 +40,7 @@ def get_combiner_jsonschema():
}


@DeveloperAPI
def get_combiner_conds():
"""Returns a list of if-then JSON clauses for each combiner type in `combiner_registry` and its properties'
constraints."""
Expand Down
6 changes: 6 additions & 0 deletions ludwig/schema/decoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BINARY, CATEGORY, NUMBER, SEQUENCE, SET, TEXT, VECTOR
from ludwig.schema import utils as schema_utils
from ludwig.schema.decoders.utils import register_decoder_config
from ludwig.schema.metadata.decoder_metadata import DECODER_METADATA


@DeveloperAPI
@dataclass(repr=False)
class BaseDecoderConfig(schema_utils.BaseMarshmallowConfig, ABC):
"""Base class for decoders."""
Expand Down Expand Up @@ -54,6 +56,7 @@ class BaseDecoderConfig(schema_utils.BaseMarshmallowConfig, ABC):
)


@DeveloperAPI
@register_decoder_config("passthrough", [BINARY, CATEGORY, NUMBER, SET, VECTOR, SEQUENCE, TEXT])
@dataclass(repr=False)
class PassthroughDecoderConfig(BaseDecoderConfig):
Expand All @@ -67,6 +70,7 @@ class PassthroughDecoderConfig(BaseDecoderConfig):
)


@DeveloperAPI
@register_decoder_config("regressor", [BINARY, NUMBER])
@dataclass(repr=False)
class RegressorConfig(BaseDecoderConfig):
Expand Down Expand Up @@ -103,6 +107,7 @@ class RegressorConfig(BaseDecoderConfig):
)


@DeveloperAPI
@register_decoder_config("projector", [VECTOR])
@dataclass(repr=False)
class ProjectorConfig(BaseDecoderConfig):
Expand Down Expand Up @@ -161,6 +166,7 @@ class ProjectorConfig(BaseDecoderConfig):
)


@DeveloperAPI
@register_decoder_config("classifier", [CATEGORY, SET])
@dataclass(repr=False)
class ClassifierConfig(BaseDecoderConfig):
Expand Down
3 changes: 3 additions & 0 deletions ludwig/schema/decoders/sequence_decoders.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import SEQUENCE, TEXT
from ludwig.schema import utils as schema_utils
from ludwig.schema.decoders.base import BaseDecoderConfig
from ludwig.schema.decoders.utils import register_decoder_config
from ludwig.schema.metadata.decoder_metadata import DECODER_METADATA


@DeveloperAPI
@register_decoder_config("generator", [SEQUENCE, TEXT])
@dataclass(repr=False)
class SequenceGeneratorDecoderConfig(BaseDecoderConfig):
Expand Down Expand Up @@ -58,6 +60,7 @@ class SequenceGeneratorDecoderConfig(BaseDecoderConfig):
)


@DeveloperAPI
@register_decoder_config("tagger", [SEQUENCE, TEXT])
@dataclass(repr=False)
class SequenceTaggerDecoderConfig(BaseDecoderConfig):
Expand Down
6 changes: 6 additions & 0 deletions ludwig/schema/decoders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

from marshmallow import fields, ValidationError

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import TYPE
from ludwig.schema import utils as schema_utils
from ludwig.utils.registry import Registry

decoder_config_registry = Registry()


@DeveloperAPI
def register_decoder_config(name: str, features: Union[str, List[str]]):
if isinstance(features, str):
features = [features]
Expand All @@ -24,14 +26,17 @@ def wrap(cls):
return wrap


@DeveloperAPI
def get_decoder_cls(feature: str, name: str):
    """Return the entry stored in `decoder_config_registry` under `feature`, then `name`.

    Both lookups use plain subscripting, so a missing key surfaces whatever error the registry raises.
    """
    entries_for_feature = decoder_config_registry[feature]
    return entries_for_feature[name]


@DeveloperAPI
def get_decoder_classes(feature: str):
    """Return everything `decoder_config_registry` holds for the given `feature` key."""
    entries_for_feature = decoder_config_registry[feature]
    return entries_for_feature


@DeveloperAPI
def get_decoder_conds(feature_type: str):
"""Returns a JSON schema of conditionals to validate against decoder types for specific feature types."""
conds = []
Expand All @@ -47,6 +52,7 @@ def get_decoder_conds(feature_type: str):
return conds


@DeveloperAPI
def DecoderDataclassField(feature_type: str, default: str):
"""Custom dataclass field that when used inside a dataclass will allow the user to specify a decoder config.

Expand Down
3 changes: 3 additions & 0 deletions ludwig/schema/defaults/defaults.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import (
AUDIO,
BAG,
Expand All @@ -20,6 +21,7 @@
from ludwig.schema.features.base import BaseFeatureConfig


@DeveloperAPI
@dataclass
class DefaultsConfig(schema_utils.BaseMarshmallowConfig):

Expand Down Expand Up @@ -50,6 +52,7 @@ class DefaultsConfig(schema_utils.BaseMarshmallowConfig):
vector: BaseFeatureConfig = DefaultsDataclassField(feature_type=VECTOR)


@DeveloperAPI
def get_defaults_jsonschema():
"""Returns a JSON schema structured to only require a `type` key and then conditionally apply a corresponding
combiner's field constraints."""
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/defaults/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from marshmallow import fields, ValidationError

import ludwig.schema.utils as schema_utils
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema.features.utils import input_mixin_registry, output_mixin_registry


@DeveloperAPI
def DefaultsDataclassField(feature_type: str):
"""Custom dataclass field that when used inside a dataclass will allow the user to specify a nested default
config for a specific feature type.
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/encoders/bag_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BAG
from ludwig.schema import utils as schema_utils
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.utils import register_encoder_config
from ludwig.schema.metadata.encoder_metadata import ENCODER_METADATA


@DeveloperAPI
@register_encoder_config("embed", BAG)
@dataclass(repr=False, order=True)
class BagEmbedWeightedConfig(BaseEncoderConfig):
Expand Down
6 changes: 5 additions & 1 deletion ludwig/schema/encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@

from marshmallow_dataclass import dataclass

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BINARY, NUMBER, VECTOR
from ludwig.schema import utils as schema_utils
from ludwig.schema.encoders.utils import register_encoder_config


@DeveloperAPI
@dataclass(order=True, repr=False)
class BaseEncoderConfig(schema_utils.BaseMarshmallowConfig, ABC):
    """Abstract parent of the encoder config dataclasses; declares only the `type` field."""

    # Required string field with no default; subclasses in this module build on it.
    type: str


@DeveloperAPI
@register_encoder_config("passthrough", [NUMBER, VECTOR])
@dataclass(order=True)
class PassthroughEncoderConfig(BaseEncoderConfig):
Expand All @@ -28,6 +31,7 @@ class PassthroughEncoderConfig(BaseEncoderConfig):
)


@DeveloperAPI
@register_encoder_config("dense", [BINARY, NUMBER, VECTOR])
@dataclass(repr=False, order=True)
class DenseEncoderConfig(BaseEncoderConfig):
Expand Down Expand Up @@ -77,7 +81,7 @@ class DenseEncoderConfig(BaseEncoderConfig):
description="Initializer for the weight matrix.",
)

norm: Union[str] = schema_utils.StringOptions(
norm: str = schema_utils.StringOptions(
["batch", "layer"],
allow_none=True,
default=None,
Expand Down
2 changes: 2 additions & 0 deletions ludwig/schema/encoders/binary_encoders.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from marshmallow_dataclass import dataclass

import ludwig.schema.utils as schema_utils
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import BINARY
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.utils import register_encoder_config


@DeveloperAPI
@register_encoder_config("passthrough", BINARY)
@dataclass(repr=False)
class BinaryPassthroughEncoderConfig(BaseEncoderConfig):
Expand Down
Loading