diff --git a/examples/fpde/conf/fractional_poisson_2d.yaml b/examples/fpde/conf/fractional_poisson_2d.yaml new file mode 100644 index 0000000000..c0b657b82c --- /dev/null +++ b/examples/fpde/conf/fractional_poisson_2d.yaml @@ -0,0 +1,69 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_fractional_poisson_2d/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 42 +output_dir: ${hydra:run.dir} +log_freq: 100 + +ALPHA: 1.8 +NPOINT_INTERIOR: 100 +NPOINT_BC: 1 +NPOINT_EVAL: 1000 + +# model settings +MODEL: + input_keys: ["x", "y"] + output_keys: ["u"] + num_layers: 4 + hidden_size: 20 + activation: "tanh" + +# training settings +TRAIN: + epochs: 20000 + iters_per_epoch: 1 + save_freq: 100 + eval_during_train: true + eval_freq: 1000 + learning_rate: 0.001 + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true + batch_size: + sup_validator: 128 diff --git a/examples/fpde/fractional_poisson_2d.py b/examples/fpde/fractional_poisson_2d.py index 6acb518f5f..86d88bb883 100644 --- a/examples/fpde/fractional_poisson_2d.py +++ b/examples/fpde/fractional_poisson_2d.py @@ -19,32 +19,50 @@ from typing import Tuple from typing import Union +import hydra import numpy as np import paddle from matplotlib import cm from matplotlib import pyplot as plt +from omegaconf import DictConfig import ppsci -from ppsci.utils import config -from ppsci.utils import logger -if __name__ == "__main__": - args = config.parse_args() - # set random seed for reproducibility - ppsci.utils.misc.set_random_seed(42) - # set training hyper-parameters - EPOCHS = 20000 if not args.epochs else args.epochs - ITERS_PER_EPOCH = 1 +def plot(x, y, input_data, output_data, label_data): + fig = plt.figure() + # plot prediction + ax1 = fig.add_subplot(121, projection="3d") + surf1 = ax1.plot_surface( + x, y, output_data["u"], cmap=cm.jet, linewidth=0, antialiased=False + ) + ax1.set_zlim(0, 1.2) + ax1.set_xlabel(r"$x$") + ax1.set_ylabel(r"$y$") + ax1.set_zlabel(r"$z$") + ax1.set_title(r"$u(x,y), label$") + fig.colorbar(surf1, ax=ax1, aspect=5, orientation="horizontal") - # set output directory - OUTPUT_DIR = ( - "./output_fractional_poisson_2d" if not args.output_dir else args.output_dir + # plot label + ax2 = fig.add_subplot(122, projection="3d") + surf2 = ax2.plot_surface( + x, y, label_data, cmap=cm.jet, linewidth=0, antialiased=False ) - logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + ax2.set_zlim(0, 1.2) + ax2.set_xlabel("x") + ax2.set_ylabel("y") + ax2.set_zlabel("z") + ax2.set_title(r"$u(x,y), prediction$") + + # Add a color bar which maps values to colors. + fig.colorbar(surf2, ax=ax2, aspect=5, orientation="horizontal") + fig.subplots_adjust(wspace=0.5, hspace=0.5) + plt.savefig("fractional_poisson_2d_result.png", dpi=400) + +def train(cfg: DictConfig): # set model - model = ppsci.arch.MLP(("x", "y"), ("u",), 4, 20) + model = ppsci.arch.MLP(**cfg.MODEL) def output_transform(in_, out): return {"u": (1 - (in_["x"] ** 2 + in_["y"] ** 2)) * out["u"]} @@ -55,19 +73,19 @@ def output_transform(in_, out): geom = {"disk": ppsci.geometry.Disk((0, 0), 1)} # set equation - ALPHA = 1.8 - equation = {"fpde": ppsci.equation.FractionalPoisson(ALPHA, geom["disk"], [8, 100])} + equation = { + "fpde": ppsci.equation.FractionalPoisson(cfg.ALPHA, geom["disk"], [8, 100]) + } # set constraint - NPOINT_INTERIOR = 100 - NPOINT_BC = 1 - def u_solution_func( out: Dict[str, Union[paddle.Tensor, np.ndarray]] ) -> Union[paddle.Tensor, np.ndarray]: if isinstance(out["x"], paddle.Tensor): - return paddle.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + ALPHA / 2) - return np.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + ALPHA / 2) + return paddle.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** ( + 1 + cfg.ALPHA / 2 + ) + return np.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + cfg.ALPHA / 2) # set transform for input data def input_data_fpde_transform( @@ -114,8 +132,8 @@ def input_data_fpde_transform( }, ), }, - "batch_size": NPOINT_INTERIOR, - "iters_per_epoch": ITERS_PER_EPOCH, + "batch_size": cfg.NPOINT_INTERIOR, + "iters_per_epoch": cfg.TRAIN.iters_per_epoch, }, ppsci.loss.MSELoss("mean"), random="Hammersley", @@ -128,11 +146,10 @@ def input_data_fpde_transform( geom["disk"], { "dataset": {"name": "IterableNamedArrayDataset"}, - "batch_size": NPOINT_BC, - "iters_per_epoch": ITERS_PER_EPOCH, + "batch_size": cfg.NPOINT_BC, + "iters_per_epoch": cfg.TRAIN.iters_per_epoch, }, ppsci.loss.MSELoss("mean"), - random="Hammersley", criteria=lambda x, y: np.isclose(x, -1), name="BC", ) @@ -143,18 +160,16 @@ def input_data_fpde_transform( } # set optimizer - optimizer = ppsci.optimizer.Adam(1e-3)(model) + optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model) # set validator - NPOINT_EVAL = 1000 - EVAL_FREQ = 1000 l2rel_metric = ppsci.validate.GeometryValidator( {"u": lambda out: out["u"]}, {"u": u_solution_func}, geom["disk"], { "dataset": "IterableNamedArrayDataset", - "total_size": NPOINT_EVAL, + "total_size": cfg.NPOINT_EVAL, }, ppsci.loss.MSELoss(), metric={"L2Rel": ppsci.metric.L2Rel()}, @@ -166,16 +181,10 @@ def input_data_fpde_transform( solver = ppsci.solver.Solver( model, constraint, - OUTPUT_DIR, - optimizer, - epochs=EPOCHS, - iters_per_epoch=ITERS_PER_EPOCH, - eval_during_train=True, - eval_freq=EVAL_FREQ, + optimizer=optimizer, equation=equation, - geom=geom, validator=validator, - eval_with_no_grad=True, + cfg=cfg, ) # train model solver.train() @@ -194,32 +203,80 @@ def input_data_fpde_transform( label_data = u_solution_func(input_data).reshape([x.shape[0], -1]) output_data = solver.predict(input_data, return_numpy=True) output_data = {k: v.reshape([x.shape[0], -1]) for k, v in output_data.items()} + plot(x, y, input_data, output_data, label_data) - fig = plt.figure() - # plot prediction - ax1 = fig.add_subplot(121, projection="3d") - surf1 = ax1.plot_surface( - x, y, output_data["u"], cmap=cm.jet, linewidth=0, antialiased=False + +def evaluate(cfg: DictConfig): + # load model + model = ppsci.load_model(cfg.pretrained_model_path) + # set geometry + geom = { + "disk": ppsci.geometry.Disk(np.array([0, 0]), np.array([1]), np.array([[0]])), + } + + def u_solution_func( + out: Dict[str, Union[paddle.Tensor, np.ndarray]] + ) -> Union[paddle.Tensor, np.ndarray]: + if isinstance(out["x"], paddle.Tensor): + return paddle.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** ( + 1 + cfg.ALPHA / 2 + ) + return np.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + cfg.ALPHA / 2) + + # set validator + l2rel_metric = ppsci.validate.GeometryValidator( + {"u": lambda out: out["u"]}, + {"u": u_solution_func}, + geom["disk"], + { + "dataset": "IterableNamedArrayDataset", + "total_size": cfg.NPOINT_EVAL, + }, + ppsci.loss.MSELoss(), + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="L2Rel_Metric", ) - ax1.set_zlim(0, 1.2) - ax1.set_xlabel(r"$x$") - ax1.set_ylabel(r"$y$") - ax1.set_zlabel(r"$z$") - ax1.set_title(r"$u(x,y), label$") - fig.colorbar(surf1, ax=ax1, aspect=5, orientation="horizontal") + validator = {l2rel_metric.name: l2rel_metric} - # plot label - ax2 = fig.add_subplot(122, projection="3d") - surf2 = ax2.plot_surface( - x, y, label_data, cmap=cm.jet, linewidth=0, antialiased=False + # initialize solver + solver = ppsci.solver.Solver( + model, + validator=validator, + cfg=cfg, ) - ax2.set_zlim(0, 1.2) - ax2.set_xlabel("x") - ax2.set_ylabel("y") - ax2.set_zlabel("z") - ax2.set_title(r"$u(x,y), prediction$") + # train model + solver.train() - # Add a color bar which maps values to colors. - fig.colorbar(surf2, ax=ax2, aspect=5, orientation="horizontal") - fig.subplots_adjust(wspace=0.5, hspace=0.5) - plt.savefig("fractional_poisson_2d_result.png", dpi=400) + # visualize prediction after finished training + theta = np.arange(0, 2 * math.pi, 0.04) + rho = np.arange(0, 1, 0.005) + mt, mr = np.meshgrid(theta, rho) + x = mr * np.cos(mt) + y = mr * np.sin(mt) + + input_data = { + "x": x.reshape([-1, 1]), + "y": y.reshape([-1, 1]), + } + + label_data = u_solution_func(input_data).reshape([x.shape[0], -1]) + output_data = solver.predict(input_data, return_numpy=True) + output_data = {k: v.reshape([x.shape[0], -1]) for k, v in output_data.items()} + + plot(x, y, input_data, output_data, label_data) + + +@hydra.main( + version_base=None, config_path="./conf", config_name="fractional_poisson_2d.yaml" +) +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/ppsci/geometry/geometry.py b/ppsci/geometry/geometry.py index 1dd9171f2a..a06168488a 100644 --- a/ppsci/geometry/geometry.py +++ b/ppsci/geometry/geometry.py @@ -316,7 +316,7 @@ def sample_boundary( if len(points) > 0: _nsuc += 1 - if _ntry >= 1000 and _nsuc == 0: + if _ntry >= 10000 and _nsuc == 0: raise ValueError( "Sample boundary points failed, " "please check correctness of geometry and given criteria."