
Commit

add ProtectedString alias field, try it out in optimizer schemas. (#2757)
ksbrar authored Nov 14, 2022
1 parent 43eca26 commit a981caf
Showing 2 changed files with 28 additions and 10 deletions.
20 changes: 10 additions & 10 deletions ludwig/schema/optimizers.py
@@ -57,7 +57,7 @@ class SGDOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.SGD
"""Points to `torch.optim.SGD`."""

type: str = schema_utils.StringOptions(["sgd"], default="sgd", allow_none=False)
type: str = schema_utils.ProtectedString("sgd")
"""Must be 'sgd' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry` (default:
'sgd')"""

@@ -78,7 +78,7 @@ class LBFGSOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.LBFGS
"""Points to `torch.optim.LBFGS`."""

type: str = schema_utils.StringOptions(["lbfgs"], default="lbfgs", allow_none=False)
type: str = schema_utils.ProtectedString("lbfgs")
"""Must be 'lbfgs' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry` (default:
'lbfgs')"""

@@ -112,7 +112,7 @@ class AdamOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.Adam
"""Points to `torch.optim.Adam`."""

type: str = schema_utils.StringOptions(["adam"], default="adam", allow_none=False)
type: str = schema_utils.ProtectedString("adam")
"""Must be 'adam' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry`
(default: 'adam')"""

@@ -146,7 +146,7 @@ class AdamWOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.AdamW
"""Points to `torch.optim.AdamW`."""

type: str = schema_utils.StringOptions(["adamw"], default="adamw", allow_none=False)
type: str = schema_utils.ProtectedString("adamw")
"""Must be 'adamw' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry`
(default: 'adamw')"""

@@ -180,7 +180,7 @@ class AdadeltaOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.Adadelta
"""Points to `torch.optim.Adadelta`."""

type: str = schema_utils.StringOptions(["adadelta"], default="adadelta", allow_none=False)
type: str = schema_utils.ProtectedString("adadelta")
"""Must be 'adadelta' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry`
(default: 'adadelta')"""

@@ -213,7 +213,7 @@ class AdagradOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.Adagrad
"""Points to `torch.optim.Adagrad`."""

type: str = schema_utils.StringOptions(["adagrad"], default="adagrad", allow_none=False)
type: str = schema_utils.ProtectedString("adagrad")
"""Must be 'adagrad' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry`
(default: 'adagrad')"""

@@ -239,7 +239,7 @@ class AdamaxOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.Adamax
"""Points to `torch.optim.Adamax`."""

type: str = schema_utils.StringOptions(["adamax"], default="adamax", allow_none=False)
type: str = schema_utils.ProtectedString("adamax")
"""Must be 'adamax' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry`
(default: 'adamax')"""

@@ -263,7 +263,7 @@ class AdamaxOptimizerConfig(BaseOptimizerConfig):
class FtrlOptimizerConfig(BaseOptimizerConfig):

# optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.Ftrl
type: str = schema_utils.StringOptions(["ftrl"], default="ftrl", allow_none=False)
type: str = schema_utils.ProtectedString("ftrl")

learning_rate_power: float = schema_utils.FloatRange(default=-0.5, max=0.0)

@@ -281,7 +281,7 @@ class NadamOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.NAdam
"""Points to `torch.optim.NAdam`."""

type: str = schema_utils.StringOptions(["nadam"], default="nadam", allow_none=False)
type: str = schema_utils.ProtectedString("nadam")

# Defaults taken from https://pytorch.org/docs/stable/generated/torch.optim.NAdam.html#torch.optim.NAdam :

@@ -308,7 +308,7 @@ class RMSPropOptimizerConfig(BaseOptimizerConfig):
optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.RMSprop
"""Points to `torch.optim.RMSprop`."""

type: str = schema_utils.StringOptions(["rmsprop"], default="rmsprop", allow_none=False)
type: str = schema_utils.ProtectedString("rmsprop")
"""Must be 'rmsprop' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry`
(default: 'rmsprop')"""

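For readers skimming the diff: every optimizer config swaps a single-option `StringOptions` field for the new alias. Below is a minimal sketch of the pattern, trimmed from the `SGDOptimizerConfig` entry above; the docstrings for the remaining hyperparameters are omitted, the real class extends `BaseOptimizerConfig`, and the import paths are assumptions inferred from the file paths in this commit.

```python
from dataclasses import dataclass
from typing import ClassVar

import torch

import ludwig.schema.utils as schema_utils


@dataclass
class SGDOptimizerConfig:  # trimmed sketch; the real class extends BaseOptimizerConfig
    optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.SGD
    """Points to `torch.optim.SGD`."""

    # Before this commit:
    #   type: str = schema_utils.StringOptions(["sgd"], default="sgd", allow_none=False)
    # After: the alias states "this field only ever holds 'sgd'" directly.
    type: str = schema_utils.ProtectedString("sgd")
```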
18 changes: 18 additions & 0 deletions ludwig/schema/utils.py
@@ -276,6 +276,24 @@ def StringOptions(
)


def ProtectedString(
    pstring: str,
    description: str = "",
    parameter_metadata: ParameterMetadata = None,
):
    """Alias for a `StringOptions` field with only one option.
    Useful primarily for `type` parameters.
    """
    return StringOptions(
        options=[pstring],
        default=pstring,
        allow_none=False,
        description=description,
        parameter_metadata=parameter_metadata,
    )


def IntegerOptions(
options: TList[int],
default: Union[None, int] = None,
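As the definition above shows, the alias is nothing more than a single-option `StringOptions` field, so the two declarations below should be interchangeable. A small sketch, assuming `ludwig.schema.utils` is importable as in the diff; the values are taken verbatim from the `AdamOptimizerConfig` change above.

```python
import ludwig.schema.utils as schema_utils

# Equivalent ways to declare a field that only accepts the literal string "adam";
# ProtectedString simply forwards to StringOptions with a one-element options list.
old_style = schema_utils.StringOptions(["adam"], default="adam", allow_none=False)
new_style = schema_utils.ProtectedString("adam")
```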
