Skip to content

Commit

Permalink
(WIP)update AMGNet code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 22, 2023
1 parent de89025 commit 7b71329
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/zh/api/data/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
- LorenzDataset
- RosslerDataset
- VtuDataset
- MeshAirfoilDataset
show_root_heading: false
170 changes: 170 additions & 0 deletions examples/amgnet/amgnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.nn import functional as F

import ppsci
from ppsci.utils import config
from ppsci.utils import logger


def train_mse_func(output_dict, label_dict, *args):
return F.mse_loss(output_dict["pred"], label_dict["label"].y)


def eval_rmse_func(output_dict, label_dict, *args):
mse_losses = [
F.mse_loss(pred, label.y)
for (pred, label) in zip(output_dict["pred"], label_dict["label"])
]
return {"RMSE": (sum(mse_losses) / len(mse_losses)) ** 0.5}


if __name__ == "__main__":
args = config.parse_args()
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)
# set output directory
OUTPUT_DIR = "./output_AMGNet" if not args.output_dir else args.output_dir
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")

# set model
# model for airfoil
model = ppsci.arch.AMGNet(
5,
3,
128,
num_layers=2,
message_passing_aggregator="sum",
message_passing_steps=6,
speed="norm",
)
# # model for cylinder
# model = ppsci.arch.AMGNet(
# 5,
# 3,
# 128,
# num_layers=2,
# message_passing_aggregator="sum",
# message_passing_steps=6,
# speed="norm",
# )

# set dataloader config
ITERS_PER_EPOCH = 42
train_dataloader_cfg = {
"dataset": {
"name": "MeshAirfoilDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_root": "./data/NACA0012_interpolate/outputs_train",
"mesh_graph_path": "./data/NACA0012_interpolate/mesh_fine.su2",
},
"batch_size": 4,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": True,
},
"num_workers": 1,
}

# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
train_dataloader_cfg,
output_expr={"pred": lambda out: out["pred"]},
loss=ppsci.loss.FunctionalLoss(train_mse_func),
name="Sup",
)
# wrap constraints together
constraint = {
sup_constraint.name: sup_constraint,
}

# set training hyper-parameters
EPOCHS = 500 if not args.epochs else args.epochs

# set optimizer
optimizer = ppsci.optimizer.Adam(5e-4)(model)

# set validator
eval_dataloader_cfg = {
"dataset": {
"name": "MeshAirfoilDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_root": "./data/NACA0012_interpolate/outputs_test",
"mesh_graph_path": "./data/NACA0012_interpolate/mesh_fine.su2",
},
"batch_size": 1,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
}
rmse_validator = ppsci.validate.SupervisedValidator(
eval_dataloader_cfg,
loss=ppsci.loss.FunctionalLoss(train_mse_func),
output_expr={"pred": lambda out: out["pred"]},
metric={"RMSE": ppsci.metric.FunctionalMetric(eval_rmse_func)},
name="RMSE_validator",
)
validator = {rmse_validator.name: rmse_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
OUTPUT_DIR,
optimizer,
None,
EPOCHS,
ITERS_PER_EPOCH,
eval_during_train=True,
eval_freq=50,
validator=validator,
eval_with_no_grad=True,
# pretrained_model_path="./output_AMGNet/checkpoints/latest"
)
# train model
solver.train()
# solver.eval()
# with solver.no_grad_context_manager(True):
# sum_loss = 0
# for index, batch in enumerate(loader):
# truefield = batch[0].y
# prefield = model(batch)
# # print(f"{index }prefield.mean() = {prefield.shape} {prefield.mean().item():.10f}")
# # log_images(
# # batch[0].pos,
# # prefield,
# # truefield,
# # trainer.data.elems_list,
# # "test",
# # index,
# # flag=my_type,
# # )
# mes_loss = criterion(prefield, truefield)
# # print(f">>> mes_loss = {mes_loss.item():.10f}")
# sum_loss += mes_loss.item()
# print(index)
# # exit()
# avg_loss = sum_loss / (len(loader))
# avg_loss = np.sqrt(avg_loss)
# root_logger.info(" trajectory_loss")
# root_logger.info(" " + str(avg_loss))
# print("trajectory_loss=", avg_loss)
# print("============finish============")
56 changes: 30 additions & 26 deletions ppsci/arch/amgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def StAS(index_A, value_A, index_S, value_S, N, kN, nor):

indices_A = index_A.numpy()
values_A = value_A.numpy()
# with misc.Timer("coo_matrix1"):
coo_A = sci_sparse.coo_matrix(
(values_A, (indices_A[0], indices_A[1])), shape=(N, N)
)
Expand Down Expand Up @@ -125,6 +126,7 @@ def StAS(index_A, value_A, index_S, value_S, N, kN, nor):

indices_A = index_St.numpy()
values_A = value_St.numpy()
# with misc.Timer("coo_matrix2"):
coo_A = sci_sparse.coo_matrix(
(values_A, (indices_A[0], indices_A[1])), shape=(kN, N)
)
Expand Down Expand Up @@ -259,21 +261,26 @@ def norm_graph_connectivity(perm, edge_index, edge_weight, score, pos, N, nor):
mask = (edge_index[0] == perm2).sum(axis=0).astype("bool")
S0 = edge_index[1][mask].reshape((1, -1))
S1 = edge_index[0][mask].reshape((1, -1))

index_S = paddle.concat([S0, S1], axis=0)
value_S = score[mask].detach().squeeze()
n_idx = paddle.zeros([N], dtype=paddle.int64)
n_idx[perm] = paddle.arange(perm.shape[0])

index_S = index_S.astype("int64")
index_S[1] = n_idx[index_S[1]]
subgraphnode_pos = pos[perm]
index_A = edge_index.clone()

if edge_weight is None:
value_A = value_S.new_ones(edge_index[0].shape[0])
else:
value_A = edge_weight.clone()

value_A = paddle.squeeze(value_A)
value_S = paddle.where(value_S == 0, paddle.to_tensor(0.0001), value_S)
eps_mask = (value_S == 0).astype(paddle.get_default_dtype())
value_S = paddle.full_like(value_S, 1e-4) * eps_mask + (1 - eps_mask) * value_S
# value_S = paddle.where(value_S == 0, paddle.to_tensor(0.0001), value_S)
attrlist = []
standard_index, _ = StAS(
index_A,
Expand All @@ -284,15 +291,21 @@ def norm_graph_connectivity(perm, edge_index, edge_weight, score, pos, N, nor):
kN,
nor,
)
# with misc.Timer("range 128"):
for i in range(128):
val_A = paddle.where(
value_A[:, i] == 0, paddle.to_tensor(0.0001), value_A[:, i]
)
mask = (value_A[:, i] == 0).astype(paddle.get_default_dtype())
val_A = paddle.full_like(mask, 1e-4) * mask + (1 - mask) * value_A[:, i]
# val_A = paddle.where(
# value_A[:, i] == 0, paddle.to_tensor(0.0001), value_A[:, i]
# )
# with misc.Timer("inner StAS"):
index_E, value_E = StAS(index_A, val_A, index_S, value_S, N, kN, nor)

if index_E.shape[1] != standard_index.shape[1]:
# with misc.Timer("FillZeros"):
index_E, value_E = FillZeros(index_E, value_E, standard_index, kN)

# with misc.Timer("remove_self_loops"):
index_E, value_E = remove_self_loops(edge_index=index_E, edge_attr=value_E)
attrlist.append(value_E)
edge_weight = paddle.stack(attrlist, axis=1)
Expand Down Expand Up @@ -445,6 +458,7 @@ def forward(self, latent_graph, speed, normalized_adj_mat=None):
nor=self.normalization,
)
elif speed == "norm":
# with misc.Timer("norm_graph_connectivity"):
subedge_index, edge_weight, subpos = norm_graph_connectivity(
perm=coarsenodes,
edge_index=cofe_graph.edge_index,
Expand Down Expand Up @@ -498,7 +512,7 @@ def forward(self, input):
class Encoder(nn.Layer):
"""Encodes node and edge features into latent features."""

def __init__(self, input_dim, make_mlp, latent_dim, mode):
def __init__(self, input_dim, make_mlp, latent_dim):
super(Encoder, self).__init__()
self._make_mlp = make_mlp
self._latent_dim = latent_dim
Expand All @@ -507,7 +521,7 @@ def __init__(self, input_dim, make_mlp, latent_dim, mode):
# else:
self.node_model = self._make_mlp(latent_dim, input_dim=input_dim) # 4

self.mesh_edge_model = self._make_mlp(latent_dim, input_dim=input_dim) # 1
self.mesh_edge_model = self._make_mlp(latent_dim, input_dim=1) # 1
"""
for _ in graph.edge_sets:
edge_model = make_mlp(latent_dim)
Expand All @@ -516,9 +530,8 @@ def __init__(self, input_dim, make_mlp, latent_dim, mode):

def forward(self, graph):
node_latents = self.node_model(graph.x)
# print(f">>> graph.x = {graph.x.shape}") # [26736, 5]
edge_latent = self.mesh_edge_model(graph.edge_attr)
# print(f">>> graph.edge_attr = {graph.edge_attr.shape}") # [105344, 1]

graph.x = node_latents
graph.edge_attr = edge_latent
return graph
Expand Down Expand Up @@ -548,7 +561,6 @@ def __init__(
num_layers,
message_passing_aggregator,
message_passing_steps,
mode,
speed,
nodes=6684,
):
Expand All @@ -560,10 +572,7 @@ def __init__(
self.min_nodes = nodes
self._message_passing_steps = message_passing_steps
self._message_passing_aggregator = message_passing_aggregator
# self.mode = mode
self.encoder = Encoder(
make_mlp=self._make_mlp, latent_dim=self._latent_dim, mode=self.mode
)
self.encoder = Encoder(input_dim, self._make_mlp, latent_dim=self._latent_dim)
self.processor = Processor(
make_mlp=self._make_mlp,
output_dim=self._latent_dim,
Expand Down Expand Up @@ -596,19 +605,14 @@ def _spa_compute(self, x, p):
node_features = self.post_processor(node_features)
return node_features

def forward(self, graphs):
batch = MyCopy(graphs[0])

for index, graph in enumerate(graphs):
if index > 0:
batch = Myadd(batch, graph)

latent_graph = self.encoder(batch)

def forward(self, x):
graphs = x["input"]
# with misc.Timer("encoder"):
latent_graph = self.encoder(graphs)
# with misc.Timer("processor"):
x, p = self.processor(latent_graph, speed=self.speed)

# with misc.Timer("_spa_compute"):
node_features = self._spa_compute(x, p)

# with misc.Timer("decoder"):
pred_field = self.decoder(node_features)

return pred_field
return {"pred": pred_field}
35 changes: 24 additions & 11 deletions ppsci/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def build_dataloader(_dataset, cfg):
# build collate_fn if specified
batch_transforms_cfg = cfg.pop("batch_transforms", None)

collate_fn = None
if isinstance(batch_transforms_cfg, dict) and batch_transforms_cfg:
collate_fn = batch_transform.default_collate_fn
if isinstance(batch_transforms_cfg, (list, tuple)):
collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)

# build init function
Expand All @@ -96,15 +96,28 @@ def build_dataloader(_dataset, cfg):
)

# build dataloader
dataloader_ = io.DataLoader(
dataset=_dataset,
places=device.get_device(),
batch_sampler=sampler,
collate_fn=collate_fn,
num_workers=cfg.get("num_workers", 0),
use_shared_memory=cfg.get("use_shared_memory", False),
worker_init_fn=init_fn,
)
if getattr(_dataset, "use_pgl", False):
# Use special dataloader from "Paddle Graph Learning" toolkit.
from pgl.utils import data as pgl_data

dataloader_ = pgl_data.Dataloader(
dataset=_dataset,
batch_size=cfg["batch_size"],
drop_last=sampler_cfg.get("drop_last", False),
shuffle=sampler_cfg.get("shuffle", False),
num_workers=cfg.get("num_workers", 1),
collate_fn=collate_fn,
)
else:
dataloader_ = io.DataLoader(
dataset=_dataset,
places=device.get_device(),
batch_sampler=sampler,
collate_fn=collate_fn,
num_workers=cfg.get("num_workers", 0),
use_shared_memory=cfg.get("use_shared_memory", False),
worker_init_fn=init_fn,
)

if len(dataloader_) == 0:
raise ValueError(
Expand Down
2 changes: 2 additions & 0 deletions ppsci/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy

from ppsci.data.dataset.airfoil_dataset import MeshAirfoilDataset
from ppsci.data.dataset.array_dataset import IterableNamedArrayDataset
from ppsci.data.dataset.array_dataset import NamedArrayDataset
from ppsci.data.dataset.csv_dataset import CSVDataset
Expand Down Expand Up @@ -46,6 +47,7 @@
"LorenzDataset",
"RosslerDataset",
"VtuDataset",
"MeshAirfoilDataset",
"build_dataset",
]

Expand Down
Loading

0 comments on commit 7b71329

Please sign in to comment.