Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
aedbfc1
allow empty optimizer when saving checkpoint
HydrogenSulfate Apr 9, 2024
e97886d
add model averaging module
HydrogenSulfate Apr 9, 2024
381c702
fix return dtype inconsistency with global dtype
HydrogenSulfate Apr 9, 2024
cf8b742
use python func instead of sympy function for pow(u,3) get a bit poor…
HydrogenSulfate Apr 10, 2024
d6f4b2b
refine AllenCahn docstring
HydrogenSulfate Apr 10, 2024
39bb505
support save and load for average model module
HydrogenSulfate Apr 10, 2024
265b7bb
add 3 ema unitests
HydrogenSulfate Apr 10, 2024
f8ce9f8
update 2023 to 2024
HydrogenSulfate Apr 10, 2024
5f61f8a
add ema config pydantic scheme
HydrogenSulfate Apr 10, 2024
13ea6c1
add avg_range for SWA
HydrogenSulfate Apr 10, 2024
c1a44fe
update field_validator for swa and ema
HydrogenSulfate Apr 10, 2024
e5fca6a
support period embedding for MLP
HydrogenSulfate Apr 10, 2024
5c88550
Keep non-float data when reading file
HydrogenSulfate Apr 10, 2024
4ce144f
Merge branch 'add_allen_cahn' into add_period_layer
HydrogenSulfate Apr 10, 2024
1be0892
update ema and save_load, printer and eval, solver module code
HydrogenSulfate Apr 10, 2024
1e77b75
add allen_cahn example
HydrogenSulfate Apr 11, 2024
21a7278
refine code
HydrogenSulfate Apr 11, 2024
6a0b36a
save buffer and non-grad required params in ema
HydrogenSulfate Apr 11, 2024
5d2de5f
add unitest for ema with buffer
HydrogenSulfate Apr 11, 2024
a0f7c33
Merge branch 'develop' into add_allen_cahn_example
HydrogenSulfate Apr 11, 2024
c9e0133
fix epoch_ema saving
HydrogenSulfate Apr 11, 2024
9d3063f
add unitest for ema state_dict
HydrogenSulfate Apr 11, 2024
186648b
refine allen_cahn_plain.py
HydrogenSulfate Apr 11, 2024
aa23fcc
fix string to floating conversion in reader.py
HydrogenSulfate Apr 11, 2024
3f56fcf
fix string to floating conversion in reader.py
HydrogenSulfate Apr 11, 2024
8d1dabf
update code and refine document
HydrogenSulfate Apr 12, 2024
9d4f4c0
Merge branch 'develop' into add_causal_train
HydrogenSulfate Apr 15, 2024
8e8e78a
correct initialization for RWF
HydrogenSulfate Apr 15, 2024
46f8be6
update docstring for arg 'random_weight' of mlp
HydrogenSulfate Apr 15, 2024
f212a3a
update docstrings
HydrogenSulfate Apr 15, 2024
e9bd799
add causal fourier rwf config
HydrogenSulfate Apr 15, 2024
c55f714
fix code in mlp.py
HydrogenSulfate Apr 16, 2024
92f4efe
refine code in mse.py
HydrogenSulfate Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/zh/api/loss/loss.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- L2RelLoss
- MAELoss
- MSELoss
- CausalMSELoss
- MSELossWithL2Decay
- IntegralLoss
- PeriodicL1Loss
Expand Down
303 changes: 303 additions & 0 deletions examples/allen_cahn/allen_cahn_causal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
"""
Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
"""

from os import path as osp

import hydra
import numpy as np
import paddle
import scipy.io as sio
from matplotlib import pyplot as plt
from omegaconf import DictConfig

import ppsci
from ppsci.utils import misc

dtype = paddle.get_default_dtype()


def plot(
t_star: np.ndarray,
x_star: np.ndarray,
u_ref: np.ndarray,
u_pred: np.ndarray,
output_dir: str,
):
fig = plt.figure(figsize=(18, 5))
TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
u_ref = u_ref.reshape([len(t_star), len(x_star)])

plt.subplot(1, 3, 1)
plt.pcolor(TT, XX, u_ref, cmap="jet")
plt.colorbar()
plt.xlabel("t")
plt.ylabel("x")
plt.title("Exact")
plt.tight_layout()

plt.subplot(1, 3, 2)
plt.pcolor(TT, XX, u_pred, cmap="jet")
plt.colorbar()
plt.xlabel("t")
plt.ylabel("x")
plt.title("Predicted")
plt.tight_layout()

plt.subplot(1, 3, 3)
plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
plt.colorbar()
plt.xlabel("t")
plt.ylabel("x")
plt.title("Absolute error")
plt.tight_layout()

fig_path = osp.join(output_dir, "ac.png")
print(f"Saving figure to {fig_path}")
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
plt.close()


def train(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# set equation
equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}

data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
t_star = data["t"].flatten().astype(dtype) # [nt, ]
x_star = data["x"].flatten().astype(dtype) # [nx, ]

u0 = u_ref[0, :] # [nx, ]

t0 = t_star[0] # float
t1 = t_star[-1] # float

x0 = x_star[0] # float
x1 = x_star[-1] # float

# set constraint
def gen_input_batch():
tx = np.random.uniform(
[t0, x0],
[t1, x1],
(cfg.TRAIN.batch_size, 2),
).astype(dtype)
return {
"t": np.sort(tx[:, 0:1], axis=0),
"x": tx[:, 1:2],
}

def gen_label_batch(input_batch):
return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}

pde_constraint = ppsci.constraint.SupervisedConstraint(
{
"dataset": {
"name": "ContinuousNamedArrayDataset",
"input": gen_input_batch,
"label": gen_label_batch,
},
},
output_expr=equation["AllenCahn"].equations,
loss=ppsci.loss.CausalMSELoss(
cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol
),
name="PDE",
)

ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
ic_label = {"u": u0.reshape([-1, 1])}
ic = ppsci.constraint.SupervisedConstraint(
{
"dataset": {
"name": "IterableNamedArrayDataset",
"input": ic_input,
"label": ic_label,
},
},
output_expr={"u": lambda out: out["u"]},
loss=ppsci.loss.MSELoss("mean"),
name="IC",
)
# wrap constraints together
constraint = {
pde_constraint.name: pde_constraint,
ic.name: ic,
}

# set optimizer
lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
**cfg.TRAIN.lr_scheduler
)()
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

# set validator
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
eval_label = {"u": u_ref.reshape([-1, 1])}
u_validator = ppsci.validate.SupervisedValidator(
{
"dataset": {
"name": "NamedArrayDataset",
"input": eval_data,
"label": eval_label,
},
"batch_size": cfg.EVAL.batch_size,
},
ppsci.loss.MSELoss("mean"),
{"u": lambda out: out["u"]},
metric={"L2Rel": ppsci.metric.L2Rel()},
name="u_validator",
)
validator = {u_validator.name: u_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
cfg.output_dir,
optimizer,
lr_scheduler,
cfg.TRAIN.epochs,
cfg.TRAIN.iters_per_epoch,
save_freq=cfg.TRAIN.save_freq,
log_freq=cfg.log_freq,
eval_during_train=True,
eval_freq=cfg.TRAIN.eval_freq,
seed=cfg.seed,
equation=equation,
validator=validator,
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
checkpoint_path=cfg.TRAIN.checkpoint_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
use_tbd=True,
cfg=cfg,
)
# train model
solver.train()
# evaluate after finished training
solver.eval()
# visualize prediction after finished training
u_pred = solver.predict(
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
)["u"]
u_pred = u_pred.reshape([len(t_star), len(x_star)])

# plot
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def evaluate(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
t_star = data["t"].flatten().astype(dtype) # [nt, ]
x_star = data["x"].flatten().astype(dtype) # [nx, ]

# set validator
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
eval_label = {"u": u_ref.reshape([-1, 1])}
u_validator = ppsci.validate.SupervisedValidator(
{
"dataset": {
"name": "NamedArrayDataset",
"input": eval_data,
"label": eval_label,
},
"batch_size": cfg.EVAL.batch_size,
},
ppsci.loss.MSELoss("mean"),
{"u": lambda out: out["u"]},
metric={"L2Rel": ppsci.metric.L2Rel()},
name="u_validator",
)
validator = {u_validator.name: u_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
output_dir=cfg.output_dir,
log_freq=cfg.log_freq,
validator=validator,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)

# evaluate after finished training
solver.eval()
# visualize prediction after finished training
u_pred = solver.predict(
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
)["u"]
u_pred = u_pred.reshape([len(t_star), len(x_star)])

# plot
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def export(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
t_star = data["t"].flatten().astype(dtype) # [nt, ]
x_star = data["x"].flatten().astype(dtype) # [nx, ]
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)

input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
# mapping data to cfg.INFER.output_keys

plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


@hydra.main(
version_base=None, config_path="./conf", config_name="allen_cahn_causal.yaml"
)
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
main()
5 changes: 3 additions & 2 deletions examples/allen_cahn/allen_cahn_plain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Reference: https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html
Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
"""

from os import path as osp
Expand Down Expand Up @@ -53,6 +53,7 @@ def plot(
plt.tight_layout()

fig_path = osp.join(output_dir, "ac.png")
print(f"Saving figure to {fig_path}")
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
plt.close()

Expand Down Expand Up @@ -101,7 +102,7 @@ def gen_label_batch(input_batch):
},
},
output_expr=equation["AllenCahn"].equations,
loss=ppsci.loss.MSELoss(),
loss=ppsci.loss.MSELoss("mean"),
name="PDE",
)

Expand Down
3 changes: 0 additions & 3 deletions examples/allen_cahn/conf/allen_cahn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ TRAIN:
batch_size: 4096
pretrained_model_path: null
checkpoint_path: null
ema:
decay: 0.9
avg_freq: 1

# evaluation settings
EVAL:
Expand Down
Loading