diff --git a/docs/zh/api/data/dataset.md b/docs/zh/api/data/dataset.md
index f2e8f0708..1f65d38a7 100644
--- a/docs/zh/api/data/dataset.md
+++ b/docs/zh/api/data/dataset.md
@@ -16,4 +16,5 @@
         - LorenzDataset
         - RosslerDataset
         - VtuDataset
+        - MeshAirfoilDataset
       show_root_heading: false
diff --git a/examples/amgnet/amgnet.py b/examples/amgnet/amgnet.py
new file mode 100644
index 000000000..81d868536
--- /dev/null
+++ b/examples/amgnet/amgnet.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.nn import functional as F
+
+import ppsci
+from ppsci.utils import config
+from ppsci.utils import logger
+
+
+def train_mse_func(output_dict, label_dict, *args):
+    return F.mse_loss(output_dict["pred"], label_dict["label"].y)
+
+
+def eval_rmse_func(output_dict, label_dict, *args):
+    mse_losses = [
+        F.mse_loss(pred, label.y)
+        for (pred, label) in zip(output_dict["pred"], label_dict["label"])
+    ]
+    return {"RMSE": (sum(mse_losses) / len(mse_losses)) ** 0.5}
+
+
+if __name__ == "__main__":
+    args = config.parse_args()
+    # set random seed for reproducibility
+    ppsci.utils.misc.set_random_seed(42)
+    # set output directory
+    OUTPUT_DIR = "./output_AMGNet" if not args.output_dir else args.output_dir
+    # initialize logger
+    logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
+
+    # set model
+    # model for airfoil
+    model = ppsci.arch.AMGNet(
+        5,
+        3,
+        128,
+        num_layers=2,
+        message_passing_aggregator="sum",
+        message_passing_steps=6,
+        speed="norm",
+    )
+    # # model for cylinder
+    # model = ppsci.arch.AMGNet(
+    #     5,
+    #     3,
+    #     128,
+    #     num_layers=2,
+    #     message_passing_aggregator="sum",
+    #     message_passing_steps=6,
+    #     speed="norm",
+    # )
+
+    # set dataloader config
+    ITERS_PER_EPOCH = 42
+    train_dataloader_cfg = {
+        "dataset": {
+            "name": "MeshAirfoilDataset",
+            "input_keys": ("input",),
+            "label_keys": ("label",),
+            "data_root": "./data/NACA0012_interpolate/outputs_train",
+            "mesh_graph_path": "./data/NACA0012_interpolate/mesh_fine.su2",
+        },
+        "batch_size": 4,
+        "sampler": {
+            "name": "BatchSampler",
+            "drop_last": False,
+            "shuffle": True,
+        },
+        "num_workers": 1,
+    }
+
+    # set constraint
+    sup_constraint = ppsci.constraint.SupervisedConstraint(
+        train_dataloader_cfg,
+        output_expr={"pred": lambda out: out["pred"]},
+        loss=ppsci.loss.FunctionalLoss(train_mse_func),
+        name="Sup",
+    )
+    # wrap constraints together
+    constraint = {
+        sup_constraint.name: sup_constraint,
+    }
+
+    # set training hyper-parameters
+    EPOCHS = 500 if not args.epochs else args.epochs
+
+    # set optimizer
+    optimizer = ppsci.optimizer.Adam(5e-4)(model)
+
+    # set validator
+    eval_dataloader_cfg = {
+        "dataset": {
+            "name": "MeshAirfoilDataset",
+            "input_keys": ("input",),
+            "label_keys": ("label",),
+            "data_root": "./data/NACA0012_interpolate/outputs_test",
+            "mesh_graph_path": "./data/NACA0012_interpolate/mesh_fine.su2",
+        },
+        "batch_size": 1,
+        "sampler": {
+            "name": "BatchSampler",
+            "drop_last": False,
+            "shuffle": False,
+        },
+    }
+    rmse_validator = ppsci.validate.SupervisedValidator(
+        eval_dataloader_cfg,
+        loss=ppsci.loss.FunctionalLoss(train_mse_func),
+        output_expr={"pred": lambda out: out["pred"]},
+        metric={"RMSE": ppsci.metric.FunctionalMetric(eval_rmse_func)},
+        name="RMSE_validator",
+    )
+    validator = {rmse_validator.name: rmse_validator}
+
+    # initialize solver
+    solver = ppsci.solver.Solver(
+        model,
+        constraint,
+        OUTPUT_DIR,
+        optimizer,
+        None,
+        EPOCHS,
+        ITERS_PER_EPOCH,
+        eval_during_train=True,
+        eval_freq=50,
+        validator=validator,
+        eval_with_no_grad=True,
+        # pretrained_model_path="./output_AMGNet/checkpoints/latest"
+    )
+    # train model
+    solver.train()
+    # solver.eval()
+    # with solver.no_grad_context_manager(True):
+    #     sum_loss = 0
+    #     for index, batch in enumerate(loader):
+    #         truefield = batch[0].y
+    #         prefield = model(batch)
+    #         # print(f"{index }prefield.mean() = {prefield.shape} {prefield.mean().item():.10f}")
+    #         # log_images(
+    #         #     batch[0].pos,
+    #         #     prefield,
+    #         #     truefield,
+    #         #     trainer.data.elems_list,
+    #         #     "test",
+    #         #     index,
+    #         #     flag=my_type,
+    #         # )
+    #         mes_loss = criterion(prefield, truefield)
+    #         # print(f">>> mes_loss = {mes_loss.item():.10f}")
+    #         sum_loss += mes_loss.item()
+    #         print(index)
+    #         # exit()
+    #     avg_loss = sum_loss / (len(loader))
+    #     avg_loss = np.sqrt(avg_loss)
+    #     root_logger.info(" trajectory_loss")
+    #     root_logger.info(" " + str(avg_loss))
+    #     print("trajectory_loss=", avg_loss)
+    #     print("============finish============")
diff --git a/ppsci/arch/amgnet.py b/ppsci/arch/amgnet.py
index d7c0a4f86..769990491 100644
--- a/ppsci/arch/amgnet.py
+++ b/ppsci/arch/amgnet.py
@@ -95,6 +95,7 @@ def StAS(index_A, value_A, index_S, value_S, N, kN, nor):
     indices_A = index_A.numpy()
     values_A = value_A.numpy()

+    # with misc.Timer("coo_matrix1"):
     coo_A = sci_sparse.coo_matrix(
         (values_A, (indices_A[0], indices_A[1])), shape=(N, N)
     )
@@ -125,6 +126,7 @@ def StAS(index_A, value_A, index_S, value_S, N, kN, nor):
     indices_A = index_St.numpy()
     values_A = value_St.numpy()

+    # with misc.Timer("coo_matrix2"):
     coo_A = sci_sparse.coo_matrix(
         (values_A, (indices_A[0], indices_A[1])), shape=(kN, N)
     )
@@ -259,21 +261,26 @@ def norm_graph_connectivity(perm, edge_index, edge_weight, score, pos, N, nor):
     mask = (edge_index[0] == perm2).sum(axis=0).astype("bool")
     S0 = edge_index[1][mask].reshape((1, -1))
     S1 = edge_index[0][mask].reshape((1, -1))
+
     index_S = paddle.concat([S0, S1], axis=0)
     value_S = score[mask].detach().squeeze()
     n_idx = paddle.zeros([N], dtype=paddle.int64)
     n_idx[perm] = paddle.arange(perm.shape[0])
+
     index_S = index_S.astype("int64")
     index_S[1] = n_idx[index_S[1]]
     subgraphnode_pos = pos[perm]
     index_A = edge_index.clone()
+
     if edge_weight is None:
         value_A = value_S.new_ones(edge_index[0].shape[0])
     else:
         value_A = edge_weight.clone()

     value_A = paddle.squeeze(value_A)
-    value_S = paddle.where(value_S == 0, paddle.to_tensor(0.0001), value_S)
+    eps_mask = (value_S == 0).astype(paddle.get_default_dtype())
+    value_S = paddle.full_like(value_S, 1e-4) * eps_mask + (1 - eps_mask) * value_S
+    # value_S = paddle.where(value_S == 0, paddle.to_tensor(0.0001), value_S)
     attrlist = []
     standard_index, _ = StAS(
         index_A,
@@ -284,15 +291,21 @@ def norm_graph_connectivity(perm, edge_index, edge_weight, score, pos, N, nor):
         kN,
         nor,
     )
+    # with misc.Timer("range 128"):
     for i in range(128):
-        val_A = paddle.where(
-            value_A[:, i] == 0, paddle.to_tensor(0.0001), value_A[:, i]
-        )
+        mask = (value_A[:, i] == 0).astype(paddle.get_default_dtype())
+        val_A = paddle.full_like(mask, 1e-4) * mask + (1 - mask) * value_A[:, i]
+        # val_A = paddle.where(
+        #     value_A[:, i] == 0, paddle.to_tensor(0.0001), value_A[:, i]
+        # )
+        # with misc.Timer("inner StAS"):
         index_E, value_E = StAS(index_A, val_A, index_S, value_S, N, kN, nor)

         if index_E.shape[1] != standard_index.shape[1]:
+            # with misc.Timer("FillZeros"):
             index_E, value_E = FillZeros(index_E, value_E, standard_index, kN)

+        # with misc.Timer("remove_self_loops"):
         index_E, value_E = remove_self_loops(edge_index=index_E, edge_attr=value_E)
         attrlist.append(value_E)
     edge_weight = paddle.stack(attrlist, axis=1)
@@ -445,6 +458,7 @@ def forward(self, latent_graph, speed, normalized_adj_mat=None):
                 nor=self.normalization,
             )
         elif speed == "norm":
+            # with misc.Timer("norm_graph_connectivity"):
             subedge_index, edge_weight, subpos = norm_graph_connectivity(
                 perm=coarsenodes,
                 edge_index=cofe_graph.edge_index,
@@ -498,7 +512,7 @@ def forward(self, input):
 class Encoder(nn.Layer):
     """Encodes node and edge features into latent features."""

-    def __init__(self, input_dim, make_mlp, latent_dim, mode):
+    def __init__(self, input_dim, make_mlp, latent_dim):
         super(Encoder, self).__init__()
         self._make_mlp = make_mlp
         self._latent_dim = latent_dim
@@ -507,7 +521,7 @@ def __init__(self, input_dim, make_mlp, latent_dim, mode):

         # else:
         self.node_model = self._make_mlp(latent_dim, input_dim=input_dim)  # 4
-        self.mesh_edge_model = self._make_mlp(latent_dim, input_dim=input_dim)  # 1
+        self.mesh_edge_model = self._make_mlp(latent_dim, input_dim=1)  # 1
         """
         for _ in graph.edge_sets:
             edge_model = make_mlp(latent_dim)
@@ -516,9 +530,8 @@ def __init__(self, input_dim, make_mlp, latent_dim, mode):

     def forward(self, graph):
         node_latents = self.node_model(graph.x)
-        # print(f">>> graph.x = {graph.x.shape}")  # [26736, 5]
         edge_latent = self.mesh_edge_model(graph.edge_attr)
-        # print(f">>> graph.edge_attr = {graph.edge_attr.shape}")  # [105344, 1]
+
         graph.x = node_latents
         graph.edge_attr = edge_latent
         return graph
@@ -548,7 +561,6 @@ def __init__(
         num_layers,
         message_passing_aggregator,
         message_passing_steps,
-        mode,
         speed,
         nodes=6684,
     ):
@@ -560,10 +572,7 @@ def __init__(
         self.min_nodes = nodes
         self._message_passing_steps = message_passing_steps
         self._message_passing_aggregator = message_passing_aggregator
-        # self.mode = mode
-        self.encoder = Encoder(
-            make_mlp=self._make_mlp, latent_dim=self._latent_dim, mode=self.mode
-        )
+        self.encoder = Encoder(input_dim, self._make_mlp, latent_dim=self._latent_dim)
         self.processor = Processor(
             make_mlp=self._make_mlp,
             output_dim=self._latent_dim,
@@ -596,19 +605,14 @@ def _spa_compute(self, x, p):
         node_features = self.post_processor(node_features)
         return node_features

-    def forward(self, graphs):
-        batch = MyCopy(graphs[0])
-
-        for index, graph in enumerate(graphs):
-            if index > 0:
-                batch = Myadd(batch, graph)
-
-        latent_graph = self.encoder(batch)
-
+    def forward(self, x):
+        graphs = x["input"]
+        # with misc.Timer("encoder"):
+        latent_graph = self.encoder(graphs)
+        # with misc.Timer("processor"):
         x, p = self.processor(latent_graph, speed=self.speed)
-
+        # with misc.Timer("_spa_compute"):
         node_features = self._spa_compute(x, p)
-
+        # with misc.Timer("decoder"):
         pred_field = self.decoder(node_features)
-
-        return pred_field
+        return {"pred": pred_field}
diff --git a/ppsci/data/__init__.py b/ppsci/data/__init__.py
index f39a663ff..0faea3beb 100644
--- a/ppsci/data/__init__.py
+++ b/ppsci/data/__init__.py
@@ -83,8 +83,8 @@ def build_dataloader(_dataset, cfg):

     # build collate_fn if specified
     batch_transforms_cfg = cfg.pop("batch_transforms", None)
-    collate_fn = None
-    if isinstance(batch_transforms_cfg, dict) and batch_transforms_cfg:
+    collate_fn = batch_transform.default_collate_fn
+    if isinstance(batch_transforms_cfg, (list, tuple)):
         collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)

     # build init function
@@ -96,15 +96,28 @@ def build_dataloader(_dataset, cfg):
     )

     # build dataloader
-    dataloader_ = io.DataLoader(
-        dataset=_dataset,
-        places=device.get_device(),
-        batch_sampler=sampler,
-        collate_fn=collate_fn,
-        num_workers=cfg.get("num_workers", 0),
-        use_shared_memory=cfg.get("use_shared_memory", False),
-        worker_init_fn=init_fn,
-    )
+    if getattr(_dataset, "use_pgl", False):
+        # Use special dataloader from "Paddle Graph Learning" toolkit.
+        from pgl.utils import data as pgl_data
+
+        dataloader_ = pgl_data.Dataloader(
+            dataset=_dataset,
+            batch_size=cfg["batch_size"],
+            drop_last=sampler_cfg.get("drop_last", False),
+            shuffle=sampler_cfg.get("shuffle", False),
+            num_workers=cfg.get("num_workers", 1),
+            collate_fn=collate_fn,
+        )
+    else:
+        dataloader_ = io.DataLoader(
+            dataset=_dataset,
+            places=device.get_device(),
+            batch_sampler=sampler,
+            collate_fn=collate_fn,
+            num_workers=cfg.get("num_workers", 0),
+            use_shared_memory=cfg.get("use_shared_memory", False),
+            worker_init_fn=init_fn,
+        )

     if len(dataloader_) == 0:
         raise ValueError(
diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py
index 58fb2c1f8..6a8412fdc 100644
--- a/ppsci/data/dataset/__init__.py
+++ b/ppsci/data/dataset/__init__.py
@@ -14,6 +14,7 @@

 import copy

+from ppsci.data.dataset.airfoil_dataset import MeshAirfoilDataset
 from ppsci.data.dataset.array_dataset import IterableNamedArrayDataset
 from ppsci.data.dataset.array_dataset import NamedArrayDataset
 from ppsci.data.dataset.csv_dataset import CSVDataset
@@ -46,6 +47,7 @@
     "LorenzDataset",
     "RosslerDataset",
     "VtuDataset",
+    "MeshAirfoilDataset",
     "build_dataset",
 ]
diff --git a/ppsci/data/process/batch_transform/__init__.py b/ppsci/data/process/batch_transform/__init__.py
index 2887aa027..91c44c8f9 100644
--- a/ppsci/data/process/batch_transform/__init__.py
+++ b/ppsci/data/process/batch_transform/__init__.py
@@ -22,9 +22,14 @@
 import numpy as np
 import paddle

+try:
+    import pgl
+except ImportError:
+    pass
+
 from ppsci.data.process import transform

-__all__ = ["build_batch_transforms"]
+__all__ = ["build_batch_transforms", "default_collate_fn"]


 def default_collate_fn(batch: List[Any]) -> Any:
@@ -39,7 +44,23 @@ def default_collate_fn(batch: List[Any]) -> Any:
         Any: Collated batch data.
     """
     sample = batch[0]
-    if isinstance(sample, np.ndarray):
+    if sample is None:
+        return None
+    elif isinstance(sample, pgl.Graph):
+        graph = pgl.Graph(
+            num_nodes=sample.num_nodes,
+            edges=sample.edges,
+        )
+        graph.x = paddle.concat([g.x for g in batch])
+        graph.y = paddle.concat([g.y for g in batch])
+        graph.edge_index = paddle.concat([g.edge_index for g in batch], axis=1)
+        graph.edge_attr = paddle.concat([g.edge_attr for g in batch])
+        graph.pos = paddle.concat([g.pos for g in batch])
+        graph.shape = [
+            len(batch),
+        ]
+        return graph
+    elif isinstance(sample, np.ndarray):
         batch = np.stack(batch, axis=0)
         return batch
     elif isinstance(sample, (paddle.Tensor, paddle.framework.core.eager.Tensor)):
diff --git a/ppsci/solver/eval.py b/ppsci/solver/eval.py
index 95a2832d1..4cac41b6d 100644
--- a/ppsci/solver/eval.py
+++ b/ppsci/solver/eval.py
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import time
 from typing import TYPE_CHECKING
 from typing import Dict
 from typing import Tuple
+from typing import Union

 import paddle
 from paddle import io
@@ -26,13 +29,41 @@
 # from ppsci.utils import profiler

 if TYPE_CHECKING:
+    from pgl.utils import data as pgl_data
+
     from ppsci import solver


+def _get_dataset_length(
+    data_loader: Union["io.DataLoader", "pgl_data.Dataloader", "io.IterableDataset"]
+) -> int:
+    """Get full dataset length of given dataloader.
+
+    Args:
+        data_loader (Union[io.DataLoader, pgl_data.Dataloader, io.IterableDataset]):
+            Given dataloader.
+
+    Returns:
+        int: Length of full dataset.
+    """
+    if isinstance(data_loader, io.DataLoader):
+        num_samples = len(data_loader.dataset)
+    elif isinstance(data_loader, io.IterableDataset):
+        num_samples = data_loader.num_samples
+    elif str(type(data_loader)) == "<class 'pgl.utils.data.dataloader.Dataloader'>":
+        num_samples = len(data_loader.dataset)
+    else:
+        raise NotImplementedError(
+            f"Can not fetch the length of given dataset({type(data_loader)})."
+        )
+
+    return num_samples
+
+
 def _eval_by_dataset(
     solver: "solver.Solver", epoch_id: int, log_freq: int
 ) -> Tuple[float, Dict[str, Dict[str, float]]]:
-    """Evaluate with computing metric on total samples.
+    """Evaluate with computing metric on total samples (default process).

     Args:
         solver (solver.Solver): Main Solver.
@@ -48,10 +79,7 @@
         all_input = misc.Prettydefaultdict(list)
         all_output = misc.Prettydefaultdict(list)
         all_label = misc.Prettydefaultdict(list)
-        if isinstance(_validator.data_loader, io.DataLoader):
-            num_samples = len(_validator.data_loader.dataset)
-        else:
-            num_samples = _validator.data_loader.num_samples
+        num_samples = _get_dataset_length(_validator.data_loader)
         loss_dict = misc.Prettydefaultdict(float)
         reader_tic = time.perf_counter()

@@ -87,18 +115,24 @@
             for key, input in input_dict.items():
                 all_input[key].append(
                     input.detach()
+                    if hasattr(input, "detach")
+                    else input
                     if solver.world_size == 1
                     else misc.all_gather(input.detach())
                 )
             for key, output in output_dict.items():
                 all_output[key].append(
                     output.detach()
+                    if hasattr(output, "detach")
+                    else output
                     if solver.world_size == 1
                     else misc.all_gather(output.detach())
                 )
             for key, label in label_dict.items():
                 all_label[key].append(
                     label.detach()
+                    if hasattr(label, "detach")
+                    else label
                     if solver.world_size == 1
                     else misc.all_gather(label.detach())
                 )
@@ -122,15 +156,18 @@

         # concate all data and discard padded sample(s)
         for key in all_input:
-            all_input[key] = paddle.concat(all_input[key])
+            if paddle.is_tensor(all_input[key][0]):
+                all_input[key] = paddle.concat(all_input[key])
             if len(all_input[key]) > num_samples:
                 all_input[key] = all_input[key][:num_samples]
         for key in all_output:
-            all_output[key] = paddle.concat(all_output[key])
+            if paddle.is_tensor(all_output[key][0]):
+                all_output[key] = paddle.concat(all_output[key])
             if len(all_output[key]) > num_samples:
                 all_output[key] = all_output[key][:num_samples]
         for key in all_label:
-            all_label[key] = paddle.concat(all_label[key])
+            if paddle.is_tensor(all_label[key][0]):
+                all_label[key] = paddle.concat(all_label[key])
             if len(all_label[key]) > num_samples:
                 all_label[key] = all_label[key][:num_samples]
@@ -174,10 +211,7 @@ def _eval_by_batch(
     """
     target_metric: float = None
     for _, _validator in solver.validator.items():
-        if isinstance(_validator.data_loader, io.DataLoader):
-            num_samples = len(_validator.data_loader.dataset)
-        else:
-            num_samples = _validator.data_loader.num_samples
+        num_samples = _get_dataset_length(_validator.data_loader)
         loss_dict = misc.Prettydefaultdict(float)
         metric_dict_group = misc.PrettyOrderedDict()

diff --git a/ppsci/solver/train.py b/ppsci/solver/train.py
index a251e668b..26de4a11f 100644
--- a/ppsci/solver/train.py
+++ b/ppsci/solver/train.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import time
 from typing import TYPE_CHECKING
diff --git a/ppsci/utils/misc.py b/ppsci/utils/misc.py
index bf4298860..a26fc3818 100644
--- a/ppsci/utils/misc.py
+++ b/ppsci/utils/misc.py
@@ -18,6 +18,7 @@
 import functools
 import os
 import random
+import time
 from typing import Callable
 from typing import Dict
 from typing import List
@@ -29,6 +30,8 @@
 from matplotlib import pyplot as plt
 from paddle import distributed as dist

+from ppsci.utils import logger
+
 __all__ = [
     "all_gather",
     "AverageMeter",
@@ -44,6 +47,7 @@
     "run_on_eval_mode",
     "run_at_rank0",
     "plot_curve",
+    "Timer",
 ]


@@ -381,3 +385,38 @@ def plot_curve(

     plt.savefig(os.path.join(output_dir, f"{xlabel}-{ylabel}_curve.jpg"))
     plt.clf()
+
+
+class Timer:
+    """Count time cost for code block within context.
+
+    Args:
+        name (str, optional): Name of timer, used to distinguish different code blocks.
+            Defaults to "Timer".
+        auto_print (bool, optional): Whether to print time cost when exiting the context.
+            Defaults to True.
+
+    Examples:
+        >>> import paddle
+        >>> from ppsci.utils import misc
+        >>> with misc.Timer(auto_print=False) as timer:
+        ...     w = sum(range(0, 10))
+        >>> print(f"time cost of 'sum(range(0, 10))' is {timer.interval:.2f}")
+        time cost of 'sum(range(0, 10))' is 0.00
+    """
+
+    interval: float  # Time cost for code within Timer context
+
+    def __init__(self, name: str = "Timer", auto_print: bool = True):
+        self.name = name
+        self.auto_print = auto_print
+
+    def __enter__(self):
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.end_time = time.perf_counter()
+        self.interval = self.end_time - self.start_time
+        if self.auto_print:
+            logger.message(f"{self.name}.time_cost = {self.interval:.2f} s")
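
The commented-out `# with misc.Timer(...)` lines scattered through `ppsci/arch/amgnet.py` refer to the `Timer` context manager introduced above in `ppsci/utils/misc.py`. A minimal sketch of using it, assuming `ppsci` is importable (the timed `time.sleep` is just a stand-in for a real code block):

```python
import time

from ppsci.utils import misc

# Named timer: on exit it logs "forward.time_cost = ... s" via ppsci's logger.
with misc.Timer("forward"):
    time.sleep(0.01)  # stand-in for the code being profiled

# Silent timer: read the measured interval afterwards instead of printing it.
with misc.Timer(auto_print=False) as timer:
    _ = sum(range(10000))
print(f"elapsed: {timer.interval:.4f} s")
```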
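The `pgl.Graph` branch added to `default_collate_fn` merges a batch of graph samples into one graph by concatenating their node-level and edge-level tensors. A small sketch of that behaviour, using two toy ring graphs whose attribute names mirror what the collate code reads; the shapes and the helper below are assumptions standing in for what `MeshAirfoilDataset` actually yields, not its real output:

```python
import paddle
import pgl

from ppsci.data.process.batch_transform import default_collate_fn


def make_toy_graph(num_nodes: int) -> "pgl.Graph":
    # Hypothetical stand-in for one dataset sample: a pgl.Graph carrying the
    # attributes that default_collate_fn concatenates (x, y, edge_index,
    # edge_attr, pos).
    edges = [(i, (i + 1) % num_nodes) for i in range(num_nodes)]
    g = pgl.Graph(num_nodes=num_nodes, edges=edges)
    g.x = paddle.rand([num_nodes, 5])  # node features
    g.y = paddle.rand([num_nodes, 3])  # node labels
    g.edge_index = paddle.transpose(
        paddle.to_tensor(edges, dtype="int64"), perm=[1, 0]
    )  # shape [2, num_edges]
    g.edge_attr = paddle.rand([len(edges), 1])  # one attribute per edge
    g.pos = paddle.rand([num_nodes, 2])  # node coordinates
    return g


batch = default_collate_fn([make_toy_graph(4), make_toy_graph(6)])
print(batch.x.shape)  # [10, 5]: node features of both graphs stacked
print(batch.shape)    # [2]: batch size recorded by the collate_fn
```

As the collate code shows, `edge_index` tensors are concatenated along axis 1 as-is, without re-indexing nodes; the merged graph is then consumed directly by AMGNet's dict-based `forward`.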