Commit 67b5a61

[Enh] add validation for hydra config (#769)

* add validation for hydra config
* update unit test for pydantic
* fix for OptimizerList
* fix
1 parent 254278a commit 67b5a61

13 files changed: +486 −18 lines

ppsci/data/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ def build_dataloader(_dataset, cfg):
         num_workers=cfg.get("num_workers", _DEFAULT_NUM_WORKERS),
         use_shared_memory=cfg.get("use_shared_memory", False),
         worker_init_fn=init_fn,
-        # TODO: Do not enable persistent_workers' below for
+        # TODO: Do not enable 'persistent_workers' below for
         # 'IndexError: pop from empty list ...' will be raised in certain cases
         # persistent_workers=cfg.get("num_workers", _DEFAULT_NUM_WORKERS) > 0,
     )

ppsci/optimizer/lr_scheduler.py

Lines changed: 2 additions & 5 deletions

@@ -739,20 +739,17 @@ class SchedulerList:
     """SchedulerList which wraps more than one scheduler.

     Args:
         scheduler_list (Tuple[lr.LRScheduler, ...]): Schedulers listed in a tuple.
-        by_epoch (bool, optional): Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
+
     Examples:
         >>> import ppsci
         >>> sch1 = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)()
         >>> sch2 = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)()
         >>> sch = ppsci.optimizer.lr_scheduler.SchedulerList((sch1, sch2))
     """

-    def __init__(
-        self, scheduler_list: Tuple[lr.LRScheduler, ...], by_epoch: bool = False
-    ):
+    def __init__(self, scheduler_list: Tuple[lr.LRScheduler, ...]):
         super().__init__()
         self._sch_list = scheduler_list
-        self.by_epoch = by_epoch

     def step(self):
         for sch in self._sch_list:
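
After this change, SchedulerList does one thing: fan out step() calls to every wrapped scheduler, with no by_epoch bookkeeping. A minimal usage sketch, reusing the constructor arguments from the docstring example above:

    import ppsci

    sch1 = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)()
    sch2 = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)()
    sch = ppsci.optimizer.lr_scheduler.SchedulerList((sch1, sch2))
    sch.step()  # advances every wrapped scheduler by one iteration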

ppsci/optimizer/optimizer.py

Lines changed: 3 additions & 0 deletions

@@ -525,3 +525,6 @@ def __getitem__(self, idx):

     def __setitem__(self, idx, opt):
         raise NotImplementedError("Can not modify any item in OptimizerList.")
+
+    def __iter__(self):
+        yield from iter(self._opt_list)
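
The new __iter__ is what lets the Solver change further below walk an OptimizerList and collect each member's scheduler. A minimal sketch, assuming opt1 and opt2 are already-built ppsci optimizers and that paddle optimizers store their scheduler in _learning_rate (the convention the Solver hunk relies on):

    from paddle import optimizer as optim  # paddle's optimizer module

    import ppsci

    opts = ppsci.optimizer.OptimizerList((opt1, opt2))
    schedulers = tuple(
        opt._learning_rate  # scheduler attached to each paddle optimizer
        for opt in opts     # iteration relies on the __iter__ added above
        if isinstance(opt._learning_rate, optim.lr.LRScheduler)
    )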

ppsci/solver/eval.py

Lines changed: 10 additions & 2 deletions

@@ -128,7 +128,11 @@ def _eval_by_dataset(
         solver.eval_time_info["batch_cost"].update(batch_cost)
         batch_size = next(iter(input_dict.values())).shape[0]
         printer.update_eval_loss(solver, loss_dict, batch_size)
-        if iter_id == 1 or iter_id % log_freq == 0:
+        if (
+            iter_id == 1
+            or iter_id % log_freq == 0
+            or iter_id == len(_validator.data_loader)
+        ):
             printer.log_eval_info(
                 solver,
                 batch_size,
@@ -247,7 +251,11 @@ def _eval_by_batch(
         solver.eval_time_info["reader_cost"].update(reader_cost)
         solver.eval_time_info["batch_cost"].update(batch_cost)
         printer.update_eval_loss(solver, loss_dict, batch_size)
-        if iter_id == 1 or iter_id % log_freq == 0:
+        if (
+            iter_id == 1
+            or iter_id % log_freq == 0
+            or iter_id == len(_validator.data_loader)
+        ):
             printer.log_eval_info(
                 solver,
                 batch_size,
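
The extra clause guarantees the final evaluation batch is always logged, even when the dataloader length is not a multiple of log_freq. A small self-contained illustration with hypothetical numbers:

    log_freq, num_batches = 5, 7  # hypothetical values
    for iter_id in range(1, num_batches + 1):
        if iter_id == 1 or iter_id % log_freq == 0 or iter_id == num_batches:
            print(f"log at iter {iter_id}")  # fires at 1, 5, and 7 (the last batch)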

ppsci/solver/solver.py

Lines changed: 25 additions & 4 deletions

@@ -74,7 +74,7 @@ class Solver:
         validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
         visualizer (Optional[Dict[str, ppsci.visualize.Visualizer]]): Visualizer dict. Defaults to None.
         use_amp (bool, optional): Whether use AMP. Defaults to False.
-        amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0".
+        amp_level (Literal["O0", "O1", "O2", "OD"], optional): AMP level. Defaults to "O1".
         pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None.
         checkpoint_path (Optional[str]): Checkpoint path. Defaults to None.
         compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluation. Defaults to False.
@@ -86,7 +86,7 @@ class Solver:
     Examples:
         >>> import ppsci
         >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
-        >>> opt = ppsci.optimizer.AdamW(1e-3)((model,))
+        >>> opt = ppsci.optimizer.AdamW(1e-3)(model)
         >>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1))
         >>> pde_constraint = ppsci.constraint.InteriorConstraint(
         ...     {"u": lambda out: out["u"]},
@@ -134,7 +134,7 @@ def __init__(
         validator: Optional[Dict[str, ppsci.validate.Validator]] = None,
         visualizer: Optional[Dict[str, ppsci.visualize.Visualizer]] = None,
         use_amp: bool = False,
-        amp_level: Literal["O1", "O2", "O0"] = "O0",
+        amp_level: Literal["O0", "O1", "O2", "OD"] = "O1",
         pretrained_model_path: Optional[str] = None,
         checkpoint_path: Optional[str] = None,
         compute_metric_by_batch: bool = False,
@@ -152,7 +152,28 @@ def __init__(
         # set optimizer
         self.optimizer = optimizer
         # set learning rate scheduler
-        self.lr_scheduler = lr_scheduler
+        if lr_scheduler is not None:
+            logger.warning(
+                "The argument 'lr_scheduler' is now automatically retrieved from "
+                "'optimizer._learning_rate' when 'optimizer' is given, so it is "
+                "recommended to remove it from the Solver's initialization arguments."
+            )
+        self.lr_scheduler = (
+            optimizer._learning_rate
+            if (
+                isinstance(optimizer, optim.Optimizer)
+                and isinstance(optimizer._learning_rate, optim.lr.LRScheduler)
+            )
+            else None
+        )
+        if isinstance(self.optimizer, ppsci.optimizer.OptimizerList):
+            self.lr_scheduler = ppsci.optimizer.lr_scheduler.SchedulerList(
+                tuple(
+                    opt._learning_rate
+                    for opt in self.optimizer
+                    if isinstance(opt._learning_rate, optim.lr.LRScheduler)
+                )
+            )

         # set training hyper-parameter
         self.epochs = epochs
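
With this hunk, the scheduler is recovered from the optimizer itself, so callers no longer pass lr_scheduler explicitly. A minimal sketch based on the docstring example above; the constraint dict and output directory are placeholders, and the positional argument order is assumed from the Solver docstring:

    import ppsci

    model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
    opt = ppsci.optimizer.AdamW(1e-3)(model)
    solver = ppsci.solver.Solver(
        model,
        constraint,   # assumed to be built beforehand, as in the docstring example
        "./output",   # hypothetical output directory
        opt,          # the scheduler, if any, is read from opt._learning_rate
    )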

ppsci/utils/callbacks.py

Lines changed: 43 additions & 5 deletions

@@ -12,20 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import importlib.util
+import inspect
+import sys
 from os import path as osp
 from typing import Any

 from hydra.experimental.callback import Callback
 from omegaconf import DictConfig

+from ppsci.utils import config as config_module
 from ppsci.utils import logger
 from ppsci.utils import misc

+RUNTIME_EXIT_CODE = 1  # for other errors
+VALIDATION_ERROR_EXIT_CODE = 2  # for invalid argument detected in config file
+

 class InitCallback(Callback):
     """Callback class for:
-    1. Fixing random seed to 'config.seed'
-    2. Initialize logger while creating output directory(if not exist).
+    1. Parsing the config dict from the given yaml file, checking its validity, and completing missing items with their default values.
+    2. Fixing the random seed to 'config.seed'.
+    3. Initializing the logger while creating the output directory (if it does not exist).

     NOTE: This callback is mainly for reducing unnecessary duplicate code in each
     example's code when running with hydra.
@@ -52,10 +60,40 @@ class InitCallback:
     """

     def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
+        # import pydantic if it is available, otherwise log an error and exit,
+        # since the schema check below depends on it
+        if importlib.util.find_spec("pydantic") is not None:
+            from pydantic import ValidationError
+        else:
+            logger.error(
+                f"ModuleNotFoundError at {__file__}:{inspect.currentframe().f_lineno}\n"
+                "Please install pydantic with `pip install pydantic` when setting"
+                " callbacks in your config yaml."
+            )
+            sys.exit(RUNTIME_EXIT_CODE)
+
+        # check the given cfg against the pre-defined pydantic schema 'SolverConfig';
+        # error(s) will be printed and the program will exit if any check fails
+        try:
+            _model_pydantic = config_module.SolverConfig(**dict(config))
+            # complete missing items with the default values pre-defined in the
+            # pydantic schema 'SolverConfig'
+            full_cfg = DictConfig(_model_pydantic.model_dump())
+        except ValidationError as e:
+            print(e)
+            sys.exit(VALIDATION_ERROR_EXIT_CODE)
+        except Exception as e:
+            print(e)
+            sys.exit(RUNTIME_EXIT_CODE)
+
         # fix random seed for reproducibility
-        misc.set_random_seed(config.seed)
+        misc.set_random_seed(full_cfg.seed)

-        # create output directory
+        # initialize the logger while creating the output directory
         logger.init_logger(
-            "ppsci", osp.join(config.output_dir, f"{config.mode}.log"), "info"
+            "ppsci",
+            osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log")
+            if full_cfg.output_dir
+            else None,
+            full_cfg.log_level,
         )
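
For context, this is the pattern the callback relies on: a pydantic model rejects ill-typed fields and fills in missing ones with defaults via model_dump(). A minimal sketch; the class and field names below are illustrative stand-ins, not the actual 'SolverConfig' schema shipped in ppsci.utils.config:

    from pydantic import BaseModel, ValidationError

    class ExampleSolverConfig(BaseModel):  # hypothetical stand-in for SolverConfig
        mode: str = "train"
        seed: int = 42
        output_dir: str = "./output"
        log_level: str = "info"

    try:
        # an ill-typed 'seed' fails validation before any training starts
        ExampleSolverConfig(mode="train", seed="not-an-int")
    except ValidationError as e:
        print(e)  # pydantic names the offending field and the expected type

    # missing items are completed with the schema's defaults
    full = ExampleSolverConfig(mode="eval").model_dump()
    # -> {'mode': 'eval', 'seed': 42, 'output_dir': './output', 'log_level': 'info'}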
