@@ -74,7 +74,7 @@ class Solver:
7474 validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
7575 visualizer (Optional[Dict[str, ppsci.visualize.Visualizer]]): Visualizer dict. Defaults to None.
7676 use_amp (bool, optional): Whether use AMP. Defaults to False.
77- amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0".
77+ amp_level (Literal["O0", "O1", "O2", "OD"], optional): AMP level. Defaults to "O1".
7878 pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None.
7979 checkpoint_path (Optional[str]): Checkpoint path. Defaults to None.
8080 compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluation. Defaults to False.
@@ -86,7 +86,7 @@ class Solver:
8686 Examples:
8787 >>> import ppsci
8888 >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
89- >>> opt = ppsci.optimizer.AdamW(1e-3)(( model,) )
89+ >>> opt = ppsci.optimizer.AdamW(1e-3)(model)
9090 >>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1))
9191 >>> pde_constraint = ppsci.constraint.InteriorConstraint(
9292 ... {"u": lambda out: out["u"]},
@@ -134,7 +134,7 @@ def __init__(
134134 validator : Optional [Dict [str , ppsci .validate .Validator ]] = None ,
135135 visualizer : Optional [Dict [str , ppsci .visualize .Visualizer ]] = None ,
136136 use_amp : bool = False ,
137- amp_level : Literal ["O1" , "O2" , "O0" ] = "O0" ,
137+ amp_level : Literal ["O0" , "O1" , "O2" , "OD" ] = "O1" ,
138138 pretrained_model_path : Optional [str ] = None ,
139139 checkpoint_path : Optional [str ] = None ,
140140 compute_metric_by_batch : bool = False ,
@@ -152,7 +152,28 @@ def __init__(
152152 # set optimizer
153153 self .optimizer = optimizer
154154 # set learning rate scheduler
155- self .lr_scheduler = lr_scheduler
155+ if lr_scheduler is not None :
156+ logger .warning (
157+ "The argument: 'lr_scheduler' now automatically retrieves from "
158+ "'optimizer._learning_rate' when 'optimizer' is given, so it is "
159+ "recommended to remove it from the Solver's initialization arguments."
160+ )
161+ self .lr_scheduler = (
162+ optimizer ._learning_rate
163+ if (
164+ isinstance (optimizer , optim .Optimizer )
165+ and isinstance (optimizer ._learning_rate , optim .lr .LRScheduler )
166+ )
167+ else None
168+ )
169+ if isinstance (self .optimizer , ppsci .optimizer .OptimizerList ):
170+ self .lr_scheduler = ppsci .optimizer .lr_scheduler .SchedulerList (
171+ tuple (
172+ opt ._learning_rate
173+ for opt in self .optimizer
174+ if isinstance (opt ._learning_rate , optim .lr .LRScheduler )
175+ )
176+ )
156177
157178 # set training hyper-parameter
158179 self .epochs = epochs
0 commit comments