Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] Add AMGNet example #549

Merged
merged 29 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
de89025
update WIP code
HydrogenSulfate Sep 20, 2023
38659ca
(WIP)update AMGNet code
HydrogenSulfate Sep 22, 2023
24d4d09
try import pgl to avoid importerror
HydrogenSulfate Sep 22, 2023
f1bd744
try import pyamg to avoid importerror
HydrogenSulfate Sep 22, 2023
83f707c
add airfoil_dataset.py
HydrogenSulfate Sep 22, 2023
e770842
add type checking for amgnet
HydrogenSulfate Sep 22, 2023
e773e20
try import pgl to avoid importerror
HydrogenSulfate Sep 22, 2023
fa2adda
refine Timer
HydrogenSulfate Sep 22, 2023
c524a27
replace pgl.Dataset with io.Dataset
HydrogenSulfate Sep 22, 2023
66954a6
update reproded code
HydrogenSulfate Sep 26, 2023
ac13693
replace ImportError with ModuleNotFoundError
HydrogenSulfate Sep 26, 2023
f0ae844
refine amgnet.py
HydrogenSulfate Sep 26, 2023
283301c
refine amgnet_airfoil.py and amgnet_cylinder.py
HydrogenSulfate Sep 26, 2023
8fee479
refine utils.py
HydrogenSulfate Sep 26, 2023
f242756
refine collate_fn
HydrogenSulfate Sep 26, 2023
031f210
fix bug in eval.py
HydrogenSulfate Sep 26, 2023
781d217
refine codes
HydrogenSulfate Oct 8, 2023
2b8b754
refine codes
HydrogenSulfate Oct 8, 2023
c3e560a
modify atol from 1e-8 to 1e-7 of UT test_navierstokes
HydrogenSulfate Oct 8, 2023
997fa3f
refine code
HydrogenSulfate Oct 9, 2023
3a6e985
add AMGNet document
HydrogenSulfate Oct 9, 2023
d593755
fix
HydrogenSulfate Oct 11, 2023
11c803d
fix
HydrogenSulfate Oct 11, 2023
2381290
avoid tensor converion in dataset, and move in to collate_fn
HydrogenSulfate Oct 17, 2023
a59bc28
Merge branch 'develop' into add_AMGNet
HydrogenSulfate Oct 17, 2023
7eb4aab
update final code
HydrogenSulfate Oct 18, 2023
876ef4d
add example for AMGNet
HydrogenSulfate Oct 18, 2023
9c04c5b
fix doc
HydrogenSulfate Oct 19, 2023
a55ae30
Merge branch 'develop' into add_AMGNet
HydrogenSulfate Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/zh/api/data/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@
- LorenzDataset
- RosslerDataset
- VtuDataset
- MeshAirfoilDataset
- MeshCylinderDataset
show_root_heading: false
162 changes: 162 additions & 0 deletions examples/amgnet/amgnet_airfoil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Dict
from typing import List

import utils
from paddle.nn import functional as F

import ppsci
from ppsci.utils import config
from ppsci.utils import logger

if TYPE_CHECKING:
import paddle
import pgl


def train_mse_func(
output_dict: Dict[str, "paddle.Tensor"], label_dict: Dict[str, "pgl.Graph"], *args
) -> paddle.Tensor:
return F.mse_loss(output_dict["pred"], label_dict["label"].y)


def eval_rmse_func(
output_dict: Dict[str, List["paddle.Tensor"]],
label_dict: Dict[str, List["pgl.Graph"]],
*args,
) -> Dict[str, float]:
mse_losses = [
F.mse_loss(pred, label.y)
for (pred, label) in zip(output_dict["pred"], label_dict["label"])
]
return {"RMSE": (sum(mse_losses) / len(mse_losses)) ** 0.5}


if __name__ == "__main__":
args = config.parse_args()
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)
# set output directory
OUTPUT_DIR = "./output_AMGNet" if not args.output_dir else args.output_dir
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")

# set airfoil model
model = ppsci.arch.AMGNet(
input_dim=5,
output_dim=3,
latent_dim=128,
num_layers=2,
message_passing_aggregator="sum",
message_passing_steps=6,
speed="norm",
)

# set dataloader config
ITERS_PER_EPOCH = 42
train_dataloader_cfg = {
"dataset": {
"name": "MeshAirfoilDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_root": "./data/NACA0012_interpolate/outputs_train",
"mesh_graph_path": "./data/NACA0012_interpolate/mesh_fine.su2",
},
"batch_size": 4,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": True,
},
"num_workers": 1,
}

# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
train_dataloader_cfg,
output_expr={"pred": lambda out: out["pred"]},
loss=ppsci.loss.FunctionalLoss(train_mse_func),
name="Sup",
)
# wrap constraints together
constraint = {sup_constraint.name: sup_constraint}

# set training hyper-parameters
EPOCHS = 500 if not args.epochs else args.epochs

# set optimizer
optimizer = ppsci.optimizer.Adam(5e-4)(model)

# set validator
eval_dataloader_cfg = {
"dataset": {
"name": "MeshAirfoilDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_root": "./data/NACA0012_interpolate/outputs_test",
"mesh_graph_path": "./data/NACA0012_interpolate/mesh_fine.su2",
},
"batch_size": 1,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
}
rmse_validator = ppsci.validate.SupervisedValidator(
eval_dataloader_cfg,
loss=ppsci.loss.FunctionalLoss(train_mse_func),
output_expr={"pred": lambda out: out["pred"]},
metric={"RMSE": ppsci.metric.FunctionalMetric(eval_rmse_func)},
name="RMSE_validator",
)
validator = {rmse_validator.name: rmse_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
OUTPUT_DIR,
optimizer,
None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以去掉吧

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个None参数属于位置参数,不能删除,删了的话代码行数会变动

EPOCHS,
ITERS_PER_EPOCH,
save_freq=50,
eval_during_train=True,
eval_freq=50,
validator=validator,
eval_with_no_grad=True,
)
# train model
solver.train()

# visualize prediction
with solver.no_grad_context_manager(True):
for index, batch in enumerate(rmse_validator.data_loader):
truefield = batch[0]["input"].y
prefield = model(batch[0])
utils.log_images(
batch[0]["input"].pos,
prefield["pred"],
truefield,
rmse_validator.data_loader.dataset.elems_list,
"test",
index,
flag="airfoil",
)
162 changes: 162 additions & 0 deletions examples/amgnet/amgnet_cylinder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Dict
from typing import List

import utils
from paddle.nn import functional as F

import ppsci
from ppsci.utils import config
from ppsci.utils import logger

if TYPE_CHECKING:
import paddle
import pgl


def train_mse_func(
output_dict: Dict[str, "paddle.Tensor"], label_dict: Dict[str, "pgl.Graph"], *args
) -> paddle.Tensor:
return F.mse_loss(output_dict["pred"], label_dict["label"].y)


def eval_rmse_func(
output_dict: Dict[str, List["paddle.Tensor"]],
label_dict: Dict[str, List["pgl.Graph"]],
*args,
) -> Dict[str, float]:
mse_losses = [
F.mse_loss(pred, label.y)
for (pred, label) in zip(output_dict["pred"], label_dict["label"])
]
return {"RMSE": (sum(mse_losses) / len(mse_losses)) ** 0.5}


if __name__ == "__main__":
args = config.parse_args()
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)
# set output directory
OUTPUT_DIR = "./output_AMGNet_Cylinder" if not args.output_dir else args.output_dir
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")

# set cylinder model
model = ppsci.arch.AMGNet(
input_dim=4,
output_dim=3,
latent_dim=128,
num_layers=2,
message_passing_aggregator="sum",
message_passing_steps=6,
speed="norm",
)

# set dataloader config
ITERS_PER_EPOCH = 42
train_dataloader_cfg = {
"dataset": {
"name": "MeshCylinderDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_root": "./data/cylinderdata/train",
"mesh_graph_path": "./data/cylinderdata/cylinder.su2",
},
"batch_size": 4,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": True,
},
"num_workers": 1,
}

# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
train_dataloader_cfg,
output_expr={"pred": lambda out: out["pred"]},
loss=ppsci.loss.FunctionalLoss(train_mse_func),
name="Sup",
)
# wrap constraints together
constraint = {sup_constraint.name: sup_constraint}

# set training hyper-parameters
EPOCHS = 500 if not args.epochs else args.epochs

# set optimizer
optimizer = ppsci.optimizer.Adam(5e-4)(model)

# set validator
eval_dataloader_cfg = {
"dataset": {
"name": "MeshCylinderDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_root": "./data/cylinderdata/test",
"mesh_graph_path": "./data/cylinderdata/cylinder.su2",
},
"batch_size": 1,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
}
rmse_validator = ppsci.validate.SupervisedValidator(
eval_dataloader_cfg,
loss=ppsci.loss.FunctionalLoss(train_mse_func),
output_expr={"pred": lambda out: out["pred"]},
metric={"RMSE": ppsci.metric.FunctionalMetric(eval_rmse_func)},
name="RMSE_validator",
)
validator = {rmse_validator.name: rmse_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
OUTPUT_DIR,
optimizer,
None,
EPOCHS,
ITERS_PER_EPOCH,
save_freq=50,
eval_during_train=True,
eval_freq=50,
validator=validator,
eval_with_no_grad=True,
)
# train model
solver.train()

# visualize prediction
with solver.no_grad_context_manager(True):
for index, batch in enumerate(rmse_validator.data_loader):
truefield = batch[0]["input"].y
prefield = model(batch[0])
utils.log_images(
batch[0]["input"].pos,
prefield["pred"],
truefield,
rmse_validator.data_loader.dataset.elems_list,
"test",
index,
flag="cylinder",
)
Loading