From 66902bc00e961e74fef1c05c68c4cf871296cbab Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Mon, 6 May 2024 16:28:22 +0800 Subject: [PATCH 1/7] add allen cahn ntk --- examples/allen_cahn/allen_cahn_defalut_ntk.py | 332 ++++++++++++++++++ examples/allen_cahn/allen_cahn_sota.py | 331 +++++++++++++++++ .../conf/allen_cahn_defalut_ntk.yaml | 96 +++++ examples/allen_cahn/conf/allen_cahn_sota.yaml | 96 +++++ ppsci/arch/mlp.py | 44 ++- ppsci/loss/mtl/__init__.py | 2 + ppsci/loss/mtl/ntk.py | 111 ++++++ 7 files changed, 1007 insertions(+), 5 deletions(-) create mode 100644 examples/allen_cahn/allen_cahn_defalut_ntk.py create mode 100644 examples/allen_cahn/allen_cahn_sota.py create mode 100644 examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml create mode 100644 examples/allen_cahn/conf/allen_cahn_sota.yaml create mode 100644 ppsci/loss/mtl/ntk.py diff --git a/examples/allen_cahn/allen_cahn_defalut_ntk.py b/examples/allen_cahn/allen_cahn_defalut_ntk.py new file mode 100644 index 0000000000..95e2909ff4 --- /dev/null +++ b/examples/allen_cahn/allen_cahn_defalut_ntk.py @@ -0,0 +1,332 @@ +""" +Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn +""" + +from os import path as osp + +import hydra +import numpy as np +import paddle +import scipy.io as sio +from matplotlib import pyplot as plt +from omegaconf import DictConfig + +import ppsci +from ppsci.loss import mtl +from ppsci.utils import misc + +dtype = paddle.get_default_dtype() + + +def plot( + t_star: np.ndarray, + x_star: np.ndarray, + u_ref: np.ndarray, + u_pred: np.ndarray, + output_dir: str, +): + fig = plt.figure(figsize=(18, 5)) + TT, XX = np.meshgrid(t_star, x_star, indexing="ij") + u_ref = u_ref.reshape([len(t_star), len(x_star)]) + + plt.subplot(1, 3, 1) + plt.pcolor(TT, XX, u_ref, cmap="jet") + plt.colorbar() + plt.xlabel("t") + plt.ylabel("x") + plt.title("Exact") + plt.tight_layout() + + plt.subplot(1, 3, 2) + plt.pcolor(TT, XX, u_pred, cmap="jet") + plt.colorbar() + plt.xlabel("t") + plt.ylabel("x") + plt.title("Predicted") + plt.tight_layout() + + plt.subplot(1, 3, 3) + plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet") + plt.colorbar() + plt.xlabel("t") + plt.ylabel("x") + plt.title("Absolute error") + plt.tight_layout() + + fig_path = osp.join(output_dir, "ac.png") + print(f"Saving figure to {fig_path}") + fig.savefig(fig_path, bbox_inches="tight", dpi=400) + plt.close() + + +def train(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + jaxsd = np.load("allen_cahn_init_weight.npz") + + sd = {} + sd["fourier_emb.kernel"] = paddle.to_tensor(jaxsd[".FourierEmbs_0.kernel"][0]) + sd["linears.0.weight_v"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.1"][0]) + sd["linears.0.weight_g"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.0"][0]) + sd["linears.0.bias"] = paddle.to_tensor(jaxsd[".Dense_0.bias"][0]) + + sd["linears.1.weight_v"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.1"][0]) + sd["linears.1.weight_g"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.0"][0]) + sd["linears.1.bias"] = paddle.to_tensor(jaxsd[".Dense_1.bias"][0]) + + sd["linears.2.weight_v"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.1"][0]) + sd["linears.2.weight_g"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.0"][0]) + sd["linears.2.bias"] = paddle.to_tensor(jaxsd[".Dense_2.bias"][0]) + + sd["linears.3.weight_v"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.1"][0]) + sd["linears.3.weight_g"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.0"][0]) + sd["linears.3.bias"] = paddle.to_tensor(jaxsd[".Dense_3.bias"][0]) + + sd["last_fc.weight_v"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.1"][0]) + sd["last_fc.weight_g"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.0"][0]) + sd["last_fc.bias"] = paddle.to_tensor(jaxsd[".Dense_4.bias"][0]) + model.set_state_dict(sd) + + # set equation + equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)} + + # set constraint + data = sio.loadmat(cfg.DATA_PATH) + u_ref = data["usol"].astype(dtype) # (nt, nx) + t_star = data["t"].flatten().astype(dtype) # [nt, ] + x_star = data["x"].flatten().astype(dtype) # [nx, ] + + u0 = u_ref[0, :] # [nx, ] + + t0 = t_star[0] # float + t1 = t_star[-1] # float + + x0 = x_star[0] # float + x1 = x_star[-1] # float + + def gen_input_batch(): + tx = np.random.uniform( + [t0, x0], + [t1, x1], + (cfg.TRAIN.batch_size, 2), + ).astype(dtype) + return { + "t": np.sort(tx[:, 0:1], axis=0), + "x": tx[:, 1:2], + } + + def gen_label_batch(input_batch): + return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)} + + pde_constraint = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "ContinuousNamedArrayDataset", + "input": gen_input_batch, + "label": gen_label_batch, + }, + }, + output_expr=equation["AllenCahn"].equations, + loss=ppsci.loss.CausalMSELoss( + cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol + ), + name="PDE", + ) + + ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])} + ic_label = {"u": u0.reshape([-1, 1])} + ic = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "IterableNamedArrayDataset", + "input": ic_input, + "label": ic_label, + }, + }, + output_expr={"u": lambda out: out["u"]}, + loss=ppsci.loss.MSELoss("mean"), + name="IC", + ) + # wrap constraints together + constraint = { + pde_constraint.name: pde_constraint, + ic.name: ic, + } + + # set optimizer + lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay( + **cfg.TRAIN.lr_scheduler + )() + optimizer = ppsci.optimizer.Adam(lr_scheduler)(model) + + # set validator + tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) + eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} + eval_label = {"u": u_ref.reshape([-1, 1])} + u_validator = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": eval_data, + "label": eval_label, + }, + "batch_size": cfg.EVAL.batch_size, + }, + ppsci.loss.MSELoss("mean"), + {"u": lambda out: out["u"]}, + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="u_validator", + ) + validator = {u_validator.name: u_validator} + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + cfg.output_dir, + optimizer, + epochs=cfg.TRAIN.epochs, + iters_per_epoch=cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + log_freq=cfg.log_freq, + eval_during_train=True, + eval_freq=cfg.TRAIN.eval_freq, + equation=equation, + validator=validator, + pretrained_model_path=cfg.TRAIN.pretrained_model_path, + checkpoint_path=cfg.TRAIN.checkpoint_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + loss_aggregator=mtl.NTK( + model, + len(constraint), + cfg.TRAIN.ntk.update_freq, + # cfg.TRAIN.ntk.momentum, + ), + cfg=cfg, + ) + # train model + solver.train() + # evaluate after finished training + solver.eval() + # visualize prediction after finished training + u_pred = solver.predict( + eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True + )["u"] + u_pred = u_pred.reshape([len(t_star), len(x_star)]) + + # plot + plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) + + +def evaluate(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + data = sio.loadmat(cfg.DATA_PATH) + u_ref = data["usol"].astype(dtype) # (nt, nx) + t_star = data["t"].flatten().astype(dtype) # [nt, ] + x_star = data["x"].flatten().astype(dtype) # [nx, ] + + # set validator + tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) + eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} + eval_label = {"u": u_ref.reshape([-1, 1])} + u_validator = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": eval_data, + "label": eval_label, + }, + "batch_size": cfg.EVAL.batch_size, + }, + ppsci.loss.MSELoss("mean"), + {"u": lambda out: out["u"]}, + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="u_validator", + ) + validator = {u_validator.name: u_validator} + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + log_freq=cfg.log_freq, + validator=validator, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + + # evaluate after finished training + solver.eval() + # visualize prediction after finished training + u_pred = solver.predict( + eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True + )["u"] + u_pred = u_pred.reshape([len(t_star), len(x_star)]) + + # plot + plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) + + +def export(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path, with_onnx=False) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + data = sio.loadmat(cfg.DATA_PATH) + u_ref = data["usol"].astype(dtype) # (nt, nx) + t_star = data["t"].flatten().astype(dtype) # [nt, ] + x_star = data["x"].flatten().astype(dtype) # [nx, ] + tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) + + input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} + output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + } + u_pred = output_dict["u"].reshape([len(t_star), len(x_star)]) + # mapping data to cfg.INFER.output_keys + + plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) + + +@hydra.main( + version_base=None, config_path="./conf", config_name="allen_cahn_defalut_ntk.yaml" +) +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) + else: + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/allen_cahn/allen_cahn_sota.py b/examples/allen_cahn/allen_cahn_sota.py new file mode 100644 index 0000000000..3226431014 --- /dev/null +++ b/examples/allen_cahn/allen_cahn_sota.py @@ -0,0 +1,331 @@ +""" +Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn +""" + +from os import path as osp + +import hydra +import numpy as np +import paddle +import scipy.io as sio +from matplotlib import pyplot as plt +from omegaconf import DictConfig + +import ppsci +from ppsci.loss import mtl +from ppsci.utils import misc + +# paddle.device.set_device("cpu") +dtype = paddle.get_default_dtype() + + +def plot( + t_star: np.ndarray, + x_star: np.ndarray, + u_ref: np.ndarray, + u_pred: np.ndarray, + output_dir: str, +): + fig = plt.figure(figsize=(18, 5)) + TT, XX = np.meshgrid(t_star, x_star, indexing="ij") + u_ref = u_ref.reshape([len(t_star), len(x_star)]) + + plt.subplot(1, 3, 1) + plt.pcolor(TT, XX, u_ref, cmap="jet") + plt.colorbar() + plt.xlabel("t") + plt.ylabel("x") + plt.title("Exact") + plt.tight_layout() + + plt.subplot(1, 3, 2) + plt.pcolor(TT, XX, u_pred, cmap="jet") + plt.colorbar() + plt.xlabel("t") + plt.ylabel("x") + plt.title("Predicted") + plt.tight_layout() + + plt.subplot(1, 3, 3) + plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet") + plt.colorbar() + plt.xlabel("t") + plt.ylabel("x") + plt.title("Absolute error") + plt.tight_layout() + + fig_path = osp.join(output_dir, "ac.png") + print(f"Saving figure to {fig_path}") + fig.savefig(fig_path, bbox_inches="tight", dpi=400) + plt.close() + + +def train(cfg: DictConfig): + # set model + model = ppsci.arch.ModifiedMLP(**cfg.MODEL) + + jaxsd = np.load("allen_cahn_init_weight.npz") + + sd = {} + sd["fourier_emb.kernel"] = paddle.to_tensor(jaxsd[".FourierEmbs_0.kernel"][0]) + sd["linears.0.weight_v"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.1"][0]) + sd["linears.0.weight_g"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.0"][0]) + sd["linears.0.bias"] = paddle.to_tensor(jaxsd[".Dense_0.bias"][0]) + + sd["linears.1.weight_v"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.1"][0]) + sd["linears.1.weight_g"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.0"][0]) + sd["linears.1.bias"] = paddle.to_tensor(jaxsd[".Dense_1.bias"][0]) + + sd["linears.2.weight_v"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.1"][0]) + sd["linears.2.weight_g"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.0"][0]) + sd["linears.2.bias"] = paddle.to_tensor(jaxsd[".Dense_2.bias"][0]) + + sd["linears.3.weight_v"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.1"][0]) + sd["linears.3.weight_g"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.0"][0]) + sd["linears.3.bias"] = paddle.to_tensor(jaxsd[".Dense_3.bias"][0]) + + sd["last_fc.weight_v"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.1"][0]) + sd["last_fc.weight_g"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.0"][0]) + sd["last_fc.bias"] = paddle.to_tensor(jaxsd[".Dense_4.bias"][0]) + model.set_state_dict(sd) + + # set equation + equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)} + + # set constraint + data = sio.loadmat(cfg.DATA_PATH) + u_ref = data["usol"].astype(dtype) # (nt, nx) + t_star = data["t"].flatten().astype(dtype) # [nt, ] + x_star = data["x"].flatten().astype(dtype) # [nx, ] + + u0 = u_ref[0, :] # [nx, ] + + t0 = t_star[0] # float + t1 = t_star[-1] # float + + x0 = x_star[0] # float + x1 = x_star[-1] # float + + def gen_input_batch(): + tx = np.random.uniform( + [t0, x0], + [t1, x1], + (cfg.TRAIN.batch_size, 2), + ).astype(dtype) + return { + "t": np.sort(tx[:, 0:1], axis=0), + "x": tx[:, 1:2], + } + + def gen_label_batch(input_batch): + return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)} + + pde_constraint = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "ContinuousNamedArrayDataset", + "input": gen_input_batch, + "label": gen_label_batch, + }, + }, + output_expr=equation["AllenCahn"].equations, + loss=ppsci.loss.CausalMSELoss( + cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol + ), + name="PDE", + ) + + ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])} + ic_label = {"u": u0.reshape([-1, 1])} + ic = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "IterableNamedArrayDataset", + "input": ic_input, + "label": ic_label, + }, + }, + output_expr={"u": lambda out: out["u"]}, + loss=ppsci.loss.MSELoss("mean"), + name="IC", + ) + # wrap constraints together + constraint = { + pde_constraint.name: pde_constraint, + ic.name: ic, + } + + # set optimizer + lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay( + **cfg.TRAIN.lr_scheduler + )() + optimizer = ppsci.optimizer.Adam(lr_scheduler)(model) + + # set validator + tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) + eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} + eval_label = {"u": u_ref.reshape([-1, 1])} + u_validator = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": eval_data, + "label": eval_label, + }, + "batch_size": cfg.EVAL.batch_size, + }, + ppsci.loss.MSELoss("mean"), + {"u": lambda out: out["u"]}, + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="u_validator", + ) + validator = {u_validator.name: u_validator} + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + cfg.output_dir, + optimizer, + epochs=cfg.TRAIN.epochs, + iters_per_epoch=cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + log_freq=cfg.log_freq, + eval_during_train=True, + eval_freq=cfg.TRAIN.eval_freq, + equation=equation, + validator=validator, + pretrained_model_path=cfg.TRAIN.pretrained_model_path, + checkpoint_path=cfg.TRAIN.checkpoint_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + loss_aggregator=mtl.NTK( + model, + len(constraint), + cfg.TRAIN.ntk.update_freq, + # cfg.TRAIN.ntk.momentum, + ), + cfg=cfg, + ) + # train model + solver.train() + # evaluate after finished training + solver.eval() + # visualize prediction after finished training + u_pred = solver.predict( + eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True + )["u"] + u_pred = u_pred.reshape([len(t_star), len(x_star)]) + + # plot + plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) + + +def evaluate(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + data = sio.loadmat(cfg.DATA_PATH) + u_ref = data["usol"].astype(dtype) # (nt, nx) + t_star = data["t"].flatten().astype(dtype) # [nt, ] + x_star = data["x"].flatten().astype(dtype) # [nx, ] + + # set validator + tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) + eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} + eval_label = {"u": u_ref.reshape([-1, 1])} + u_validator = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": eval_data, + "label": eval_label, + }, + "batch_size": cfg.EVAL.batch_size, + }, + ppsci.loss.MSELoss("mean"), + {"u": lambda out: out["u"]}, + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="u_validator", + ) + validator = {u_validator.name: u_validator} + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + log_freq=cfg.log_freq, + validator=validator, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + + # evaluate after finished training + solver.eval() + # visualize prediction after finished training + u_pred = solver.predict( + eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True + )["u"] + u_pred = u_pred.reshape([len(t_star), len(x_star)]) + + # plot + plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) + + +def export(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path, with_onnx=False) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + predictor = pinn_predictor.PINNPredictor(cfg) + data = sio.loadmat(cfg.DATA_PATH) + u_ref = data["usol"].astype(dtype) # (nt, nx) + t_star = data["t"].flatten().astype(dtype) # [nt, ] + x_star = data["x"].flatten().astype(dtype) # [nx, ] + tx_star = misc.cartesian_product(t_star, x_star).astype(dtype) + + input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]} + output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + } + u_pred = output_dict["u"].reshape([len(t_star), len(x_star)]) + # mapping data to cfg.INFER.output_keys + + plot(t_star, x_star, u_ref, u_pred, cfg.output_dir) + + +@hydra.main(version_base=None, config_path="./conf", config_name="allen_cahn_sota.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) + else: + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml b/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml new file mode 100644 index 0000000000..7952f5f727 --- /dev/null +++ b/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml @@ -0,0 +1,96 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_allen_cahn_defalut_ntk/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - INFER.pretrained_model_path + - mode + - output_dir + - log_freq + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 42 +output_dir: ${hydra:run.dir} +log_freq: 100 +use_tbd: false + +DATA_PATH: ./dataset/allen_cahn.mat + +# model settings +MODEL: + input_keys: [t, x] + output_keys: [u] + num_layers: 4 + hidden_size: 256 + activation: tanh + periods: + x: [2.0, False] + fourier: + dim: 256 + scale: 1.0 + random_weight: + mean: 0.5 + std: 0.1 + +# training settings +TRAIN: + epochs: 200 + iters_per_epoch: 1000 + save_freq: 10 + eval_during_train: true + eval_freq: 1 + lr_scheduler: + epochs: ${TRAIN.epochs} + iters_per_epoch: ${TRAIN.iters_per_epoch} + learning_rate: 1.0e-3 + gamma: 0.9 + decay_steps: 2000 + by_epoch: false + batch_size: 4096 + pretrained_model_path: null + checkpoint_path: null + causal: + n_chunks: 32 + tol: 1.0 + ntk: + update_freq: 1000 + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true + batch_size: 4096 + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_plain_pretrained.pdparams + export_path: ./inference/allen_cahn + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + batch_size: 1024 diff --git a/examples/allen_cahn/conf/allen_cahn_sota.yaml b/examples/allen_cahn/conf/allen_cahn_sota.yaml new file mode 100644 index 0000000000..c6e4c587b0 --- /dev/null +++ b/examples/allen_cahn/conf/allen_cahn_sota.yaml @@ -0,0 +1,96 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_allen_cahn_sota/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - INFER.pretrained_model_path + - mode + - output_dir + - log_freq + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 42 +output_dir: ${hydra:run.dir} +log_freq: 100 +use_tbd: false + +DATA_PATH: ./dataset/allen_cahn.mat + +# model settings +MODEL: + input_keys: [t, x] + output_keys: [u] + num_layers: 4 + hidden_size: 256 + activation: tanh + periods: + x: [2.0, False] + fourier: + dim: 256 + scale: 1.0 + random_weight: + mean: 0.5 + std: 0.1 + +# training settings +TRAIN: + epochs: 200 + iters_per_epoch: 1000 + save_freq: 10 + eval_during_train: true + eval_freq: 1 + lr_scheduler: + epochs: ${TRAIN.epochs} + iters_per_epoch: ${TRAIN.iters_per_epoch} + learning_rate: 1.0e-3 + gamma: 0.9 + decay_steps: 2000 + by_epoch: false + batch_size: 4096 + pretrained_model_path: null + checkpoint_path: null + causal: + n_chunks: 32 + tol: 1.0 + ntk: + update_freq: 1000 + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true + batch_size: 4096 + +# inference settings +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_plain_pretrained.pdparams + export_path: ./inference/allen_cahn + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + batch_size: 1024 diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py index 42b93cc364..8239d35e9e 100644 --- a/ppsci/arch/mlp.py +++ b/ppsci/arch/mlp.py @@ -361,12 +361,19 @@ def __init__( weight_norm: bool = False, input_dim: Optional[int] = None, output_dim: Optional[int] = None, + periods: Optional[Dict[int, Tuple[float, bool]]] = None, + fourier: Optional[Dict[str, Union[float, int]]] = None, + random_weight: Optional[Dict[str, float]] = None, ): super().__init__() self.input_keys = input_keys self.output_keys = output_keys self.linears = [] self.acts = [] + self.periods = periods + self.fourier = fourier + if periods: + self.period_emb = PeriodEmbedding(periods) if isinstance(hidden_size, int): if not isinstance(num_layers, int): raise ValueError("num_layers should be an int") @@ -376,6 +383,17 @@ def __init__( # initialize FC layer(s) cur_size = len(self.input_keys) if input_dim is None else input_dim + if input_dim is None and periods: + # period embeded channel(s) will be doubled automatically + # if input_dim is not specified + cur_size += len(periods) + + if fourier: + self.fourier_emb = FourierEmbedding( + cur_size, fourier["dim"], fourier["scale"] + ) + cur_size = fourier["dim"] + self.embed_u = nn.Sequential( ( WeightNormLinear(cur_size, hidden_size[0]) @@ -402,11 +420,20 @@ def __init__( ) for i, _size in enumerate(hidden_size): - self.linears.append( - WeightNormLinear(cur_size, _size) - if weight_norm - else nn.Linear(cur_size, _size) - ) + if weight_norm: + self.linears.append(WeightNormLinear(cur_size, _size)) + elif random_weight: + self.linears.append( + RandomWeightFactorization( + cur_size, + _size, + mean=random_weight["mean"], + std=random_weight["std"], + ) + ) + else: + self.linears.append(nn.Linear(cur_size, _size)) + # initialize activation function self.acts.append( act_mod.get_activation(activation) @@ -457,7 +484,14 @@ def forward(self, x): if self._input_transform is not None: x = self._input_transform(x) + if self.periods: + x = self.period_emb(x) + y = self.concat_to_tensor(x, self.input_keys, axis=-1) + + if self.fourier: + y = self.fourier_emb(y) + y = self.forward_tensor(y) y = self.split_to_dict(y, self.output_keys, axis=-1) diff --git a/ppsci/loss/mtl/__init__.py b/ppsci/loss/mtl/__init__.py index 35f3b73d90..f33ba78b00 100644 --- a/ppsci/loss/mtl/__init__.py +++ b/ppsci/loss/mtl/__init__.py @@ -17,6 +17,7 @@ from ppsci.loss.mtl.agda import AGDA from ppsci.loss.mtl.base import LossAggregator from ppsci.loss.mtl.grad_norm import GradNorm +from ppsci.loss.mtl.ntk import NTK from ppsci.loss.mtl.pcgrad import PCGrad from ppsci.loss.mtl.relobralo import Relobralo from ppsci.loss.mtl.sum import Sum @@ -28,6 +29,7 @@ "PCGrad", "Relobralo", "Sum", + "NTK", ] diff --git a/ppsci/loss/mtl/ntk.py b/ppsci/loss/mtl/ntk.py new file mode 100644 index 0000000000..68239cd934 --- /dev/null +++ b/ppsci/loss/mtl/ntk.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List + +import paddle +from paddle import nn + +from ppsci.loss.mtl import base + + +class NTK(base.LossAggregator): + def __init__( + self, + model: nn.Layer, + num_losses: int = 1, + update_freq: int = 1000, + ) -> None: + super().__init__(model) + self.step = 0 + self.num_losses = num_losses + self.update_freq = update_freq + self.register_buffer("weight", paddle.ones([num_losses])) + + def compute_diag_ntk(self, params, batch): + ics_ntk = paddle.vmap(self.ntk_fn, (None, None, None, 0))( + self.u_net, params, self.t0, self.x_star + ) + + # Consider the effect of causal weights + use_causal = False + if use_causal: + # sort the time step for causal loss + sorted_batch = paddle.sort(batch[:, 0]) + batch = paddle.concat([sorted_batch, batch[:, 1].unsqueeze(1)], axis=1) + res_ntk = paddle.vmap(self.ntk_fn, (None, None, 0, 0))( + self.r_net, params, batch[:, 0], batch[:, 1] + ) + res_ntk = paddle.reshape( + res_ntk, [self.num_chunks, -1] + ) # shape: (num_chunks, -1) + res_ntk = paddle.mean( + res_ntk, axis=1 + ) # average convergence rate over each chunk + _, casual_weights = self.res_and_w(params, batch) + res_ntk = res_ntk * casual_weights # multiply by causal weights + else: + res_ntk = paddle.vmap(self.ntk_fn, (None, None, 0, 0))( + self.r_net, params, batch[:, 0], batch[:, 1] + ) + + ntk_dict = {"ics": ics_ntk, "res": res_ntk} + + return ntk_dict + + def _compute_weight(self, losses): + ntk_sum = 0 + ntk_value = [] + for loss in losses: + loss.backward(retain_graph=True) # NOTE: Keep graph for loss backward + with paddle.no_grad(): + grad = paddle.concat( + [ + p.grad.reshape([-1]) + for p in self.model.parameters() + if p.grad is not None + ] + ) + ntk_value.append( + paddle.sqrt( + paddle.sum(grad.detach() ** 2), + ) + ) + + ntk_sum += paddle.sum(paddle.stack(ntk_value, axis=0)) + ntk_weight = [(ntk_sum / x) for x in ntk_value] + + return ntk_weight + + def __call__(self, losses: List["paddle.Tensor"], step: int = 0) -> "paddle.Tensor": + assert len(losses) == self.num_losses, ( + f"Length of given losses({len(losses)}) should be equal to " + f"num_losses({self.num_losses})." + ) + self.step = step + + # compute current loss with moving weights + loss = self.weight[0] * losses[0] + for i in range(1, len(losses)): + loss += self.weight[i] * losses[i] + + # update moving weights every 'update_freq' steps + if self.step % self.update_freq == 0: + computed_weight = self._compute_weight(losses) + for i in range(self.num_losses): + self.weight[i].set_value(computed_weight[i]) + + return loss From 4c14d417beb43ba9626d76374645a073ed23338a Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Mon, 6 May 2024 16:30:56 +0800 Subject: [PATCH 2/7] update code --- ppsci/loss/mtl/ntk.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/ppsci/loss/mtl/ntk.py b/ppsci/loss/mtl/ntk.py index 68239cd934..523f94a4ab 100644 --- a/ppsci/loss/mtl/ntk.py +++ b/ppsci/loss/mtl/ntk.py @@ -35,37 +35,6 @@ def __init__( self.update_freq = update_freq self.register_buffer("weight", paddle.ones([num_losses])) - def compute_diag_ntk(self, params, batch): - ics_ntk = paddle.vmap(self.ntk_fn, (None, None, None, 0))( - self.u_net, params, self.t0, self.x_star - ) - - # Consider the effect of causal weights - use_causal = False - if use_causal: - # sort the time step for causal loss - sorted_batch = paddle.sort(batch[:, 0]) - batch = paddle.concat([sorted_batch, batch[:, 1].unsqueeze(1)], axis=1) - res_ntk = paddle.vmap(self.ntk_fn, (None, None, 0, 0))( - self.r_net, params, batch[:, 0], batch[:, 1] - ) - res_ntk = paddle.reshape( - res_ntk, [self.num_chunks, -1] - ) # shape: (num_chunks, -1) - res_ntk = paddle.mean( - res_ntk, axis=1 - ) # average convergence rate over each chunk - _, casual_weights = self.res_and_w(params, batch) - res_ntk = res_ntk * casual_weights # multiply by causal weights - else: - res_ntk = paddle.vmap(self.ntk_fn, (None, None, 0, 0))( - self.r_net, params, batch[:, 0], batch[:, 1] - ) - - ntk_dict = {"ics": ics_ntk, "res": res_ntk} - - return ntk_dict - def _compute_weight(self, losses): ntk_sum = 0 ntk_value = [] From c58f3398fb2dc82385e20e081d8971bf040d73b5 Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Mon, 6 May 2024 17:03:23 +0800 Subject: [PATCH 3/7] add allen cahn ntk --- examples/allen_cahn/allen_cahn_defalut_ntk.py | 25 ------------------ examples/allen_cahn/allen_cahn_sota.py | 26 ------------------- .../conf/allen_cahn_defalut_ntk.yaml | 2 +- examples/allen_cahn/conf/allen_cahn_sota.yaml | 2 +- 4 files changed, 2 insertions(+), 53 deletions(-) diff --git a/examples/allen_cahn/allen_cahn_defalut_ntk.py b/examples/allen_cahn/allen_cahn_defalut_ntk.py index 95e2909ff4..ec8e073151 100644 --- a/examples/allen_cahn/allen_cahn_defalut_ntk.py +++ b/examples/allen_cahn/allen_cahn_defalut_ntk.py @@ -63,31 +63,6 @@ def train(cfg: DictConfig): # set model model = ppsci.arch.MLP(**cfg.MODEL) - jaxsd = np.load("allen_cahn_init_weight.npz") - - sd = {} - sd["fourier_emb.kernel"] = paddle.to_tensor(jaxsd[".FourierEmbs_0.kernel"][0]) - sd["linears.0.weight_v"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.1"][0]) - sd["linears.0.weight_g"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.0"][0]) - sd["linears.0.bias"] = paddle.to_tensor(jaxsd[".Dense_0.bias"][0]) - - sd["linears.1.weight_v"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.1"][0]) - sd["linears.1.weight_g"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.0"][0]) - sd["linears.1.bias"] = paddle.to_tensor(jaxsd[".Dense_1.bias"][0]) - - sd["linears.2.weight_v"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.1"][0]) - sd["linears.2.weight_g"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.0"][0]) - sd["linears.2.bias"] = paddle.to_tensor(jaxsd[".Dense_2.bias"][0]) - - sd["linears.3.weight_v"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.1"][0]) - sd["linears.3.weight_g"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.0"][0]) - sd["linears.3.bias"] = paddle.to_tensor(jaxsd[".Dense_3.bias"][0]) - - sd["last_fc.weight_v"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.1"][0]) - sd["last_fc.weight_g"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.0"][0]) - sd["last_fc.bias"] = paddle.to_tensor(jaxsd[".Dense_4.bias"][0]) - model.set_state_dict(sd) - # set equation equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)} diff --git a/examples/allen_cahn/allen_cahn_sota.py b/examples/allen_cahn/allen_cahn_sota.py index 3226431014..6bc16ea907 100644 --- a/examples/allen_cahn/allen_cahn_sota.py +++ b/examples/allen_cahn/allen_cahn_sota.py @@ -15,7 +15,6 @@ from ppsci.loss import mtl from ppsci.utils import misc -# paddle.device.set_device("cpu") dtype = paddle.get_default_dtype() @@ -64,31 +63,6 @@ def train(cfg: DictConfig): # set model model = ppsci.arch.ModifiedMLP(**cfg.MODEL) - jaxsd = np.load("allen_cahn_init_weight.npz") - - sd = {} - sd["fourier_emb.kernel"] = paddle.to_tensor(jaxsd[".FourierEmbs_0.kernel"][0]) - sd["linears.0.weight_v"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.1"][0]) - sd["linears.0.weight_g"] = paddle.to_tensor(jaxsd[".Dense_0.kernel.0"][0]) - sd["linears.0.bias"] = paddle.to_tensor(jaxsd[".Dense_0.bias"][0]) - - sd["linears.1.weight_v"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.1"][0]) - sd["linears.1.weight_g"] = paddle.to_tensor(jaxsd[".Dense_1.kernel.0"][0]) - sd["linears.1.bias"] = paddle.to_tensor(jaxsd[".Dense_1.bias"][0]) - - sd["linears.2.weight_v"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.1"][0]) - sd["linears.2.weight_g"] = paddle.to_tensor(jaxsd[".Dense_2.kernel.0"][0]) - sd["linears.2.bias"] = paddle.to_tensor(jaxsd[".Dense_2.bias"][0]) - - sd["linears.3.weight_v"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.1"][0]) - sd["linears.3.weight_g"] = paddle.to_tensor(jaxsd[".Dense_3.kernel.0"][0]) - sd["linears.3.bias"] = paddle.to_tensor(jaxsd[".Dense_3.bias"][0]) - - sd["last_fc.weight_v"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.1"][0]) - sd["last_fc.weight_g"] = paddle.to_tensor(jaxsd[".Dense_4.kernel.0"][0]) - sd["last_fc.bias"] = paddle.to_tensor(jaxsd[".Dense_4.bias"][0]) - model.set_state_dict(sd) - # set equation equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)} diff --git a/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml b/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml index 7952f5f727..a0f7ae4e94 100644 --- a/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml +++ b/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml @@ -79,7 +79,7 @@ EVAL: # inference settings INFER: - pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_plain_pretrained.pdparams + pretrained_model_path: null export_path: ./inference/allen_cahn pdmodel_path: ${INFER.export_path}.pdmodel pdpiparams_path: ${INFER.export_path}.pdiparams diff --git a/examples/allen_cahn/conf/allen_cahn_sota.yaml b/examples/allen_cahn/conf/allen_cahn_sota.yaml index c6e4c587b0..8b8a8f675e 100644 --- a/examples/allen_cahn/conf/allen_cahn_sota.yaml +++ b/examples/allen_cahn/conf/allen_cahn_sota.yaml @@ -79,7 +79,7 @@ EVAL: # inference settings INFER: - pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_plain_pretrained.pdparams + pretrained_model_path: null export_path: ./inference/allen_cahn pdmodel_path: ${INFER.export_path}.pdmodel pdpiparams_path: ${INFER.export_path}.pdiparams From 321e7f2f4145b7fb38bcc5531ae8949c9caffc57 Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Mon, 6 May 2024 17:15:12 +0800 Subject: [PATCH 4/7] update code --- examples/allen_cahn/conf/allen_cahn_sota.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/allen_cahn/conf/allen_cahn_sota.yaml b/examples/allen_cahn/conf/allen_cahn_sota.yaml index 8b8a8f675e..37a0ca1409 100644 --- a/examples/allen_cahn/conf/allen_cahn_sota.yaml +++ b/examples/allen_cahn/conf/allen_cahn_sota.yaml @@ -43,14 +43,14 @@ MODEL: x: [2.0, False] fourier: dim: 256 - scale: 1.0 + scale: 2.0 random_weight: - mean: 0.5 + mean: 1.0 std: 0.1 # training settings TRAIN: - epochs: 200 + epochs: 300 iters_per_epoch: 1000 save_freq: 10 eval_during_train: true @@ -60,9 +60,9 @@ TRAIN: iters_per_epoch: ${TRAIN.iters_per_epoch} learning_rate: 1.0e-3 gamma: 0.9 - decay_steps: 2000 + decay_steps: 5000 by_epoch: false - batch_size: 4096 + batch_size: 8192 pretrained_model_path: null checkpoint_path: null causal: From bf04e9719407d4467c3f094e6969a993f0da7181 Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Mon, 6 May 2024 17:21:16 +0800 Subject: [PATCH 5/7] update code --- examples/allen_cahn/allen_cahn_defalut_ntk.py | 1 - examples/allen_cahn/allen_cahn_sota.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/allen_cahn/allen_cahn_defalut_ntk.py b/examples/allen_cahn/allen_cahn_defalut_ntk.py index ec8e073151..28e4f7da65 100644 --- a/examples/allen_cahn/allen_cahn_defalut_ntk.py +++ b/examples/allen_cahn/allen_cahn_defalut_ntk.py @@ -176,7 +176,6 @@ def gen_label_batch(input_batch): model, len(constraint), cfg.TRAIN.ntk.update_freq, - # cfg.TRAIN.ntk.momentum, ), cfg=cfg, ) diff --git a/examples/allen_cahn/allen_cahn_sota.py b/examples/allen_cahn/allen_cahn_sota.py index 6bc16ea907..1fb4b7055b 100644 --- a/examples/allen_cahn/allen_cahn_sota.py +++ b/examples/allen_cahn/allen_cahn_sota.py @@ -176,7 +176,6 @@ def gen_label_batch(input_batch): model, len(constraint), cfg.TRAIN.ntk.update_freq, - # cfg.TRAIN.ntk.momentum, ), cfg=cfg, ) From 72837f48c424527ad833577a7ce4bc437889a57c Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Mon, 6 May 2024 17:28:23 +0800 Subject: [PATCH 6/7] update code --- ppsci/loss/mse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppsci/loss/mse.py b/ppsci/loss/mse.py index 8863dc8c0c..6553245c44 100644 --- a/ppsci/loss/mse.py +++ b/ppsci/loss/mse.py @@ -117,7 +117,7 @@ class CausalMSELoss(base.Loss): Examples: >>> import paddle - >>> from ppsci.loss import MSELoss + >>> from ppsci.loss import CausalMSELoss >>> output_dict = {'u': paddle.to_tensor([[0.5, 0.9, 1.0], [1.1, -1.3, 0.0]])} >>> label_dict = {'u': paddle.to_tensor([[-1.8, 1.0, -0.1], [-0.2, 2.5, 2.0]])} From c2c70222664accb159558795554c2eadecb83efc Mon Sep 17 00:00:00 2001 From: xusuyong <2209245477@qq.com> Date: Wed, 29 May 2024 19:28:39 +0800 Subject: [PATCH 7/7] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml b/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml index a0f7ae4e94..0a4eac375c 100644 --- a/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml +++ b/examples/allen_cahn/conf/allen_cahn_defalut_ntk.yaml @@ -43,9 +43,9 @@ MODEL: x: [2.0, False] fourier: dim: 256 - scale: 1.0 + scale: 2.0 random_weight: - mean: 0.5 + mean: 1.0 std: 0.1 # training settings @@ -60,7 +60,7 @@ TRAIN: iters_per_epoch: ${TRAIN.iters_per_epoch} learning_rate: 1.0e-3 gamma: 0.9 - decay_steps: 2000 + decay_steps: 5000 by_epoch: false batch_size: 4096 pretrained_model_path: null