1818
1919# This file is for step1: training a embedding model.
2020# This file is based on PaddleScience/ppsci API.
21+ from os import path as osp
2122
2223import hydra
2324import numpy as np
2425import paddle
2526from omegaconf import DictConfig
2627
2728import ppsci
29+ from ppsci .utils import logger
2830
2931
3032def get_mean_std (data : np .ndarray ):
@@ -38,6 +40,11 @@ def get_mean_std(data: np.ndarray):
3840
3941
4042def train (cfg : DictConfig ):
43+ # set random seed for reproducibility
44+ ppsci .utils .misc .set_random_seed (cfg .seed )
45+ # initialize logger
46+ logger .init_logger ("ppsci" , osp .join (cfg .output_dir , f"{ cfg .mode } .log" ), "info" )
47+
4148 weights = (1.0 * (cfg .TRAIN_BLOCK_SIZE - 1 ), 1.0e4 * cfg .TRAIN_BLOCK_SIZE )
4249 regularization_key = "k_matrix"
4350 # manually build constraint(s)
@@ -130,9 +137,12 @@ def train(cfg: DictConfig):
130137 solver = ppsci .solver .Solver (
131138 model ,
132139 constraint ,
133- optimizer = optimizer ,
140+ cfg .output_dir ,
141+ optimizer ,
142+ epochs = cfg .TRAIN .epochs ,
143+ iters_per_epoch = ITERS_PER_EPOCH ,
144+ eval_during_train = True ,
134145 validator = validator ,
135- cfg = cfg ,
136146 )
137147 # train model
138148 solver .train ()
@@ -141,6 +151,11 @@ def train(cfg: DictConfig):
141151
142152
143153def evaluate (cfg : DictConfig ):
154+ # set random seed for reproducibility
155+ ppsci .utils .misc .set_random_seed (cfg .seed )
156+ # initialize logger
157+ logger .init_logger ("ppsci" , osp .join (cfg .output_dir , f"{ cfg .mode } .log" ), "info" )
158+
144159 weights = (1.0 * (cfg .TRAIN_BLOCK_SIZE - 1 ), 1.0e4 * cfg .TRAIN_BLOCK_SIZE )
145160 regularization_key = "k_matrix"
146161 # manually build constraint(s)
0 commit comments