Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions examples/fpde/conf/fractional_poisson_2d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- _self_

hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_fractional_poisson_2d/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode} # name of logfile
chdir: false # keep current working direcotry unchaned
config:
override_dirname:
exclude_keys:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- mode
- output_dir
- log_freq
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 100

ALPHA: 1.8
NPOINT_INTERIOR: 100
NPOINT_BC: 1
NPOINT_EVAL: 1000

# model settings
MODEL:
input_keys: ["x", "y"]
output_keys: ["u"]
num_layers: 4
hidden_size: 20
activation: "tanh"

# training settings
TRAIN:
epochs: 20000
iters_per_epoch: 1
save_freq: 100
eval_during_train: true
eval_freq: 1000
learning_rate: 0.001
pretrained_model_path: null
checkpoint_path: null

# evaluation settings
EVAL:
pretrained_model_path: null
eval_with_no_grad: true
batch_size:
sup_validator: 128
181 changes: 119 additions & 62 deletions examples/fpde/fractional_poisson_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,50 @@
from typing import Tuple
from typing import Union

import hydra
import numpy as np
import paddle
from matplotlib import cm
from matplotlib import pyplot as plt
from omegaconf import DictConfig

import ppsci
from ppsci.utils import config
from ppsci.utils import logger

if __name__ == "__main__":
args = config.parse_args()
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)

# set training hyper-parameters
EPOCHS = 20000 if not args.epochs else args.epochs
ITERS_PER_EPOCH = 1
def plot(x, y, input_data, output_data, label_data):
fig = plt.figure()
# plot prediction
ax1 = fig.add_subplot(121, projection="3d")
surf1 = ax1.plot_surface(
x, y, output_data["u"], cmap=cm.jet, linewidth=0, antialiased=False
)
ax1.set_zlim(0, 1.2)
ax1.set_xlabel(r"$x$")
ax1.set_ylabel(r"$y$")
ax1.set_zlabel(r"$z$")
ax1.set_title(r"$u(x,y), label$")
fig.colorbar(surf1, ax=ax1, aspect=5, orientation="horizontal")

# set output directory
OUTPUT_DIR = (
"./output_fractional_poisson_2d" if not args.output_dir else args.output_dir
# plot label
ax2 = fig.add_subplot(122, projection="3d")
surf2 = ax2.plot_surface(
x, y, label_data, cmap=cm.jet, linewidth=0, antialiased=False
)
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
ax2.set_zlim(0, 1.2)
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.set_zlabel("z")
ax2.set_title(r"$u(x,y), prediction$")

# Add a color bar which maps values to colors.
fig.colorbar(surf2, ax=ax2, aspect=5, orientation="horizontal")
fig.subplots_adjust(wspace=0.5, hspace=0.5)
plt.savefig("fractional_poisson_2d_result.png", dpi=400)


def train(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(("x", "y"), ("u",), 4, 20)
model = ppsci.arch.MLP(**cfg.MODEL)

def output_transform(in_, out):
return {"u": (1 - (in_["x"] ** 2 + in_["y"] ** 2)) * out["u"]}
Expand All @@ -55,19 +73,19 @@ def output_transform(in_, out):
geom = {"disk": ppsci.geometry.Disk((0, 0), 1)}

# set equation
ALPHA = 1.8
equation = {"fpde": ppsci.equation.FractionalPoisson(ALPHA, geom["disk"], [8, 100])}
equation = {
"fpde": ppsci.equation.FractionalPoisson(cfg.ALPHA, geom["disk"], [8, 100])
}

# set constraint
NPOINT_INTERIOR = 100
NPOINT_BC = 1

def u_solution_func(
out: Dict[str, Union[paddle.Tensor, np.ndarray]]
) -> Union[paddle.Tensor, np.ndarray]:
if isinstance(out["x"], paddle.Tensor):
return paddle.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + ALPHA / 2)
return np.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + ALPHA / 2)
return paddle.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (
1 + cfg.ALPHA / 2
)
return np.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + cfg.ALPHA / 2)

# set transform for input data
def input_data_fpde_transform(
Expand Down Expand Up @@ -114,8 +132,8 @@ def input_data_fpde_transform(
},
),
},
"batch_size": NPOINT_INTERIOR,
"iters_per_epoch": ITERS_PER_EPOCH,
"batch_size": cfg.NPOINT_INTERIOR,
"iters_per_epoch": cfg.TRAIN.iters_per_epoch,
},
ppsci.loss.MSELoss("mean"),
random="Hammersley",
Expand All @@ -128,11 +146,10 @@ def input_data_fpde_transform(
geom["disk"],
{
"dataset": {"name": "IterableNamedArrayDataset"},
"batch_size": NPOINT_BC,
"iters_per_epoch": ITERS_PER_EPOCH,
"batch_size": cfg.NPOINT_BC,
"iters_per_epoch": cfg.TRAIN.iters_per_epoch,
},
ppsci.loss.MSELoss("mean"),
random="Hammersley",
criteria=lambda x, y: np.isclose(x, -1),
name="BC",
)
Expand All @@ -143,18 +160,16 @@ def input_data_fpde_transform(
}

# set optimizer
optimizer = ppsci.optimizer.Adam(1e-3)(model)
optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model)

# set validator
NPOINT_EVAL = 1000
EVAL_FREQ = 1000
l2rel_metric = ppsci.validate.GeometryValidator(
{"u": lambda out: out["u"]},
{"u": u_solution_func},
geom["disk"],
{
"dataset": "IterableNamedArrayDataset",
"total_size": NPOINT_EVAL,
"total_size": cfg.NPOINT_EVAL,
},
ppsci.loss.MSELoss(),
metric={"L2Rel": ppsci.metric.L2Rel()},
Expand All @@ -166,16 +181,10 @@ def input_data_fpde_transform(
solver = ppsci.solver.Solver(
model,
constraint,
OUTPUT_DIR,
optimizer,
epochs=EPOCHS,
iters_per_epoch=ITERS_PER_EPOCH,
eval_during_train=True,
eval_freq=EVAL_FREQ,
optimizer=optimizer,
equation=equation,
geom=geom,
validator=validator,
eval_with_no_grad=True,
cfg=cfg,
)
# train model
solver.train()
Expand All @@ -194,32 +203,80 @@ def input_data_fpde_transform(
label_data = u_solution_func(input_data).reshape([x.shape[0], -1])
output_data = solver.predict(input_data, return_numpy=True)
output_data = {k: v.reshape([x.shape[0], -1]) for k, v in output_data.items()}
plot(x, y, input_data, output_data, label_data)

fig = plt.figure()
# plot prediction
ax1 = fig.add_subplot(121, projection="3d")
surf1 = ax1.plot_surface(
x, y, output_data["u"], cmap=cm.jet, linewidth=0, antialiased=False

def evaluate(cfg: DictConfig):
# load model
model = ppsci.load_model(cfg.pretrained_model_path)
# set geometry
geom = {
"disk": ppsci.geometry.Disk(np.array([0, 0]), np.array([1]), np.array([[0]])),
}

def u_solution_func(
out: Dict[str, Union[paddle.Tensor, np.ndarray]]
) -> Union[paddle.Tensor, np.ndarray]:
if isinstance(out["x"], paddle.Tensor):
return paddle.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (
1 + cfg.ALPHA / 2
)
return np.abs(1 - (out["x"] ** 2 + out["y"] ** 2)) ** (1 + cfg.ALPHA / 2)

# set validator
l2rel_metric = ppsci.validate.GeometryValidator(
{"u": lambda out: out["u"]},
{"u": u_solution_func},
geom["disk"],
{
"dataset": "IterableNamedArrayDataset",
"total_size": cfg.NPOINT_EVAL,
},
ppsci.loss.MSELoss(),
metric={"L2Rel": ppsci.metric.L2Rel()},
name="L2Rel_Metric",
)
ax1.set_zlim(0, 1.2)
ax1.set_xlabel(r"$x$")
ax1.set_ylabel(r"$y$")
ax1.set_zlabel(r"$z$")
ax1.set_title(r"$u(x,y), label$")
fig.colorbar(surf1, ax=ax1, aspect=5, orientation="horizontal")
validator = {l2rel_metric.name: l2rel_metric}

# plot label
ax2 = fig.add_subplot(122, projection="3d")
surf2 = ax2.plot_surface(
x, y, label_data, cmap=cm.jet, linewidth=0, antialiased=False
# initialize solver
solver = ppsci.solver.Solver(
model,
validator=validator,
cfg=cfg,
)
ax2.set_zlim(0, 1.2)
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.set_zlabel("z")
ax2.set_title(r"$u(x,y), prediction$")
# train model
solver.train()

# Add a color bar which maps values to colors.
fig.colorbar(surf2, ax=ax2, aspect=5, orientation="horizontal")
fig.subplots_adjust(wspace=0.5, hspace=0.5)
plt.savefig("fractional_poisson_2d_result.png", dpi=400)
# visualize prediction after finished training
theta = np.arange(0, 2 * math.pi, 0.04)
rho = np.arange(0, 1, 0.005)
mt, mr = np.meshgrid(theta, rho)
x = mr * np.cos(mt)
y = mr * np.sin(mt)

input_data = {
"x": x.reshape([-1, 1]),
"y": y.reshape([-1, 1]),
}

label_data = u_solution_func(input_data).reshape([x.shape[0], -1])
output_data = solver.predict(input_data, return_numpy=True)
output_data = {k: v.reshape([x.shape[0], -1]) for k, v in output_data.items()}

plot(x, y, input_data, output_data, label_data)


@hydra.main(
version_base=None, config_path="./conf", config_name="fractional_poisson_2d.yaml"
)
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion ppsci/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def sample_boundary(
if len(points) > 0:
_nsuc += 1

if _ntry >= 1000 and _nsuc == 0:
if _ntry >= 10000 and _nsuc == 0:
raise ValueError(
"Sample boundary points failed, "
"please check correctness of geometry and given criteria."
Expand Down