diff --git a/README.md b/README.md index f5e4e873c8..b35a507c82 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,8 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 | 问题类型 | 案例名称 | 优化算法 | 模型类型 | 训练方式 | 数据集 | 参考资料 | |-----|---------|-----|---------|----|---------|---------| -| 一维线性对流问题 | [1D 线性对流](https://paddlescience-docs.readthedocs.io/zh/examples/adv_cvit.md) | 数据驱动 | ViT | 监督学习 | [Data](https://github.com/Zhengyu-Huang/Operator-Learning/tree/main/data) | [Paper](https://arxiv.org/abs/2405.13998) | -| 非定常不可压流体 | [2D 方腔浮力驱动流](https://paddlescience-docs.readthedocs.io/zh/examples/ns_cvit.md) | 数据驱动 | ViT | 监督学习 | [Data](https://huggingface.co/datasets/pdearena/NavierStokes-2D) | [Paper](https://arxiv.org/abs/2405.13998) | +| 一维线性对流问题 | [1D 线性对流](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/adv_cvit/) | 数据驱动 | ViT | 监督学习 | [Data](https://github.com/Zhengyu-Huang/Operator-Learning/tree/main/data) | [Paper](https://arxiv.org/abs/2405.13998) | +| 非定常不可压流体 | [2D 方腔浮力驱动流](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/ns_cvit/) | 数据驱动 | ViT | 监督学习 | [Data](https://huggingface.co/datasets/pdearena/NavierStokes-2D) | [Paper](https://arxiv.org/abs/2405.13998) | | 定常不可压流体 | [Re3200 2D 定常方腔流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/ldc2d_steady) | 机理驱动 | MLP | 无监督学习 | - | | | 定常不可压流体 | [2D 达西流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/darcy2d) | 机理驱动 | MLP | 无监督学习 | - | | | 定常不可压流体 | [2D 管道流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/labelfree_DNN_surrogate) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/1906.02382) | diff --git a/deploy/python_infer/pinn_predictor.py b/deploy/python_infer/pinn_predictor.py index 4b360e0e7c..c37481d11c 100644 --- a/deploy/python_infer/pinn_predictor.py +++ b/deploy/python_infer/pinn_predictor.py @@ -159,7 +159,7 @@ def predict( ed = min(num_samples, batch_id * batch_size) batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict} else: - batch_input_dict = {key: input_dict[key] for key in input_dict} + batch_input_dict = {**input_dict} # send batch input data to input handle(s) if self.engine != "onnx": diff --git a/docs/zh/api/data/process/batch_transform.md b/docs/zh/api/data/process/batch_transform.md index 91bdcf53bb..54f4c9ac59 100644 --- a/docs/zh/api/data/process/batch_transform.md +++ b/docs/zh/api/data/process/batch_transform.md @@ -1,9 +1,10 @@ # Data.batch_transform(批预处理) 模块 -::: ppsci.data.process.transform +::: ppsci.data.process.batch_transform handler: python options: members: - build_batch_transforms + - FunctionalBatchTransform - default_collate_fn show_root_heading: true diff --git a/docs/zh/development.md b/docs/zh/development.md index 2742949608..bd0468b20f 100644 --- a/docs/zh/development.md +++ b/docs/zh/development.md @@ -926,7 +926,7 @@ PaddleScience 是一个开源的代码库,由多人共同参与开发,因此 PaddleScience 使用了包括 [isort](https://github.com/PyCQA/isort#installing-isort)、[black](https://github.com/psf/black) 等自动化代码检查、格式化插件, 让 commit 的代码遵循 python [PEP8](https://pep8.org/) 代码风格规范。 -因此在 commit 您的代码之前,请务必先执行以下命令安装 `pre-commit`,否则提交的 PR 会被 code-style 检测到代码未格式化而无法合入。 +因此在 commit 您的代码之前,请务必先在 `PaddleScience/` 目录下执行以下命令安装 `pre-commit`,否则提交的 PR 会被 code-style 检测到代码未格式化而无法合入。 ``` sh pip install pre-commit diff --git a/docs/zh/examples/adv_cvit.md b/docs/zh/examples/adv_cvit.md index e7fa22e9dd..dc6fab75ba 100644 --- a/docs/zh/examples/adv_cvit.md +++ b/docs/zh/examples/adv_cvit.md @@ -49,17 +49,22 @@ CVit 作为一种算子学习模型,以输入函数 $u$、函数 $s$ 的查询 本问题求解如下方程: Formulation The 1D advection equation in $\Omega=[0,1)$ is + $$ \begin{aligned} & \frac{\partial u}{\partial t}+c \frac{\partial u}{\partial x}=0 \quad x \in \Omega, \\ & u(0)=u_0 \end{aligned} $$ + where $c=1$ is the constant advection speed, and periodic boundary conditions are imposed. We are interested in the map from the initial $u_0$ to solution $u(\cdot, T)$ at $T=0.5$. The initial condition $u_0$ is assumed to be + $$ u_0=-1+2 \mathbb{1}\left\{\tilde{u_0} \geq 0\right\} $$ + where $\widetilde{u_0}$ a centered Gaussian + $$ \widetilde{u_0} \sim \mathbb{N}(0, \mathrm{C}) \quad \text { and } \quad \mathrm{C}=\left(-\Delta+\tau^2\right)^{-d} \text {; } $$ diff --git a/docs/zh/examples/ns_cvit.md b/docs/zh/examples/ns_cvit.md index f256a026ca..08316402d2 100644 --- a/docs/zh/examples/ns_cvit.md +++ b/docs/zh/examples/ns_cvit.md @@ -54,6 +54,7 @@ CVit 作为一种算子学习模型,以输入函数 $u$、函数 $s$ 的查询 本问题基于固定方腔的不可压 buoyancy-driven flow 即方腔内的浮力驱动流动问题,求解如下方程: Formulation We consider the vorticity-stream $(\omega-\psi)$ formulation of the incompressible Navier-Stokes equations on a two-dimensional periodic domain, $D=D_u=D_v=[0,2 \pi]^2$ : + $$ \begin{aligned} & \frac{\partial \omega}{\partial t}+(v \cdot \nabla) \omega-v \Delta \omega=f^{\prime} \\ diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index 4ff980568b..6a4085fdf1 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -710,7 +710,7 @@ def predict( self, input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]], expr_dict: Optional[Dict[str, Callable]] = None, - batch_size: int = 64, + batch_size: Optional[int] = 64, no_grad: bool = True, return_numpy: bool = False, ) -> Dict[str, Union[paddle.Tensor, np.ndarray]]: @@ -720,7 +720,9 @@ def predict( input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict. expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guide to compute equation variable with callable function. Defaults to None. - batch_size (int, optional): Predicting by batch size. Defaults to 64. + batch_size (Optional[int]): Predicting by batch size. If None, data in + `input_dict` will be used directly for inference without any batch slicing. + Defaults to 64. no_grad (bool): Whether set stop_gradient=True for entire prediction, mainly for memory-efficiency. Defaults to True. return_numpy (bool): Whether convert result from Tensor to numpy ndarray. @@ -773,26 +775,32 @@ def predict( if self.world_size > 1 else input_dict ) - local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size + local_batch_num = ( + (local_num_samples_pad + (batch_size - 1)) // batch_size + if batch_size is not None + else 1 + ) pred_dict = misc.Prettydefaultdict(list) with self.no_grad_context_manager(no_grad), self.no_sync_context_manager( self.world_size > 1, self.model ): for batch_id in range(local_batch_num): - batch_input_dict = {} - st = batch_id * batch_size - ed = min(local_num_samples_pad, (batch_id + 1) * batch_size) - # prepare batch input dict - for key in local_input_dict: - if not paddle.is_tensor(local_input_dict[key]): - batch_input_dict[key] = paddle.to_tensor( - local_input_dict[key][st:ed], paddle.get_default_dtype() - ) - else: - batch_input_dict[key] = local_input_dict[key][st:ed] - batch_input_dict[key].stop_gradient = no_grad + batch_input_dict = {} + if batch_size is not None: + st = batch_id * batch_size + ed = min(local_num_samples_pad, (batch_id + 1) * batch_size) + for key in local_input_dict: + if not paddle.is_tensor(local_input_dict[key]): + batch_input_dict[key] = paddle.to_tensor( + local_input_dict[key][st:ed], paddle.get_default_dtype() + ) + else: + batch_input_dict[key] = local_input_dict[key][st:ed] + batch_input_dict[key].stop_gradient = no_grad + else: + batch_input_dict = {**local_input_dict} # forward with self.autocast_context_manager(self.use_amp, self.amp_level): diff --git a/ppsci/utils/config.py b/ppsci/utils/config.py index 7550340407..e3eb5c2221 100644 --- a/ppsci/utils/config.py +++ b/ppsci/utils/config.py @@ -42,7 +42,7 @@ class EMAConfig(BaseModel): def decay_check(cls, v): if v <= 0 or v >= 1: raise ValueError( - f"'decay' should be in (0, 1) when is type of float, but got {v}" + f"'ema.decay' should be in (0, 1) when is type of float, but got {v}" ) return v @@ -50,7 +50,7 @@ def decay_check(cls, v): def avg_freq_check(cls, v): if v <= 0: raise ValueError( - "'avg_freq' should be a positive integer when is type of int, " + "'ema.avg_freq' should be a positive integer when is type of int, " f"but got {v}" ) return v @@ -63,15 +63,17 @@ class SWAConfig(BaseModel): @field_validator("avg_range") def avg_range_check(cls, v, info: ValidationInfo): if isinstance(v, tuple) and v[0] > v[1]: - raise ValueError(f"'avg_range' should be a valid range, but got {v}.") + raise ValueError( + f"'swa.avg_range' should be a valid range, but got {v}." + ) if isinstance(v, tuple) and v[0] < 0: raise ValueError( - "The start epoch of 'avg_range' should be a non-negtive integer" + "The start epoch of 'swa.avg_range' should be a non-negtive integer" f" , but got {v[0]}." ) if isinstance(v, tuple) and v[1] > info.data["epochs"]: raise ValueError( - "The end epoch of 'avg_range' should not be lager than " + "The end epoch of 'swa.avg_range' should not be lager than " f"'epochs'({info.data['epochs']}), but got {v[1]}." ) return v @@ -80,7 +82,7 @@ def avg_range_check(cls, v, info: ValidationInfo): def avg_freq_check(cls, v): if v <= 0: raise ValueError( - "'avg_freq' should be a positive integer when is type of int, " + "'swa.avg_freq' should be a positive integer when is type of int, " f"but got {v}" ) return v @@ -107,7 +109,7 @@ class TrainConfig(BaseModel): def epochs_check(cls, v): if v <= 0: raise ValueError( - "'epochs' should be a positive integer when is type of int, " + "'TRAIN.epochs' should be a positive integer when is type of int, " f"but got {v}" ) return v @@ -116,7 +118,7 @@ def epochs_check(cls, v): def iters_per_epoch_check(cls, v): if v <= 0: raise ValueError( - "'iters_per_epoch' should be a positive integer when is type of int" + "'TRAIN.iters_per_epoch' should be a positive integer when is type of int" f", but got {v}" ) return v @@ -125,7 +127,7 @@ def iters_per_epoch_check(cls, v): def update_freq_check(cls, v): if v <= 0: raise ValueError( - "'update_freq' should be a positive integer when is type of int" + "'TRAIN.update_freq' should be a positive integer when is type of int" f", but got {v}" ) return v @@ -134,7 +136,7 @@ def update_freq_check(cls, v): def save_freq_check(cls, v): if v < 0: raise ValueError( - "'save_freq' should be a non-negtive integer when is type of int" + "'TRAIN.save_freq' should be a non-negtive integer when is type of int" f", but got {v}" ) return v @@ -144,8 +146,8 @@ def start_eval_epoch_check(cls, v, info: ValidationInfo): if info.data["eval_during_train"]: if v <= 0: raise ValueError( - f"'start_eval_epoch' should be a positive integer when " - f"'eval_during_train' is True, but got {v}" + f"'TRAIN.start_eval_epoch' should be a positive integer when " + f"'TRAIN.eval_during_train' is True, but got {v}" ) return v @@ -154,8 +156,8 @@ def eval_freq_check(cls, v, info: ValidationInfo): if info.data["eval_during_train"]: if v <= 0: raise ValueError( - f"'eval_freq' should be a positive integer when " - f"'eval_during_train' is True, but got {v}" + f"'TRAIN.eval_freq' should be a positive integer when " + f"'TRAIN.eval_during_train' is True, but got {v}" ) return v @@ -176,6 +178,15 @@ class EvalConfig(BaseModel): pretrained_model_path: Optional[str] = None eval_with_no_grad: bool = False compute_metric_by_batch: bool = False + batch_size: Optional[int] = 256 + + @field_validator("batch_size") + def batch_size_check(cls, v): + if isinstance(v, int) and v <= 0: + raise ValueError( + f"'EVAL.batch_size' should be greater than 0 or None, but got {v}" + ) + return v class InferConfig(BaseModel): """ @@ -203,12 +214,12 @@ class InferConfig(BaseModel): def engine_check(cls, v, info: ValidationInfo): if v == "tensorrt" and info.data["device"] != "gpu": raise ValueError( - "'device' should be 'gpu' when 'engine' is 'tensorrt', " + "'INFER.device' should be 'gpu' when 'INFER.engine' is 'tensorrt', " f"but got '{info.data['device']}'" ) if v == "mkldnn" and info.data["device"] != "cpu": raise ValueError( - "'device' should be 'cpu' when 'engine' is 'mkldnn', " + "'INFER.device' should be 'cpu' when 'INFER.engine' is 'mkldnn', " f"but got '{info.data['device']}'" ) @@ -218,21 +229,25 @@ def engine_check(cls, v, info: ValidationInfo): def min_subgraph_size_check(cls, v): if v <= 0: raise ValueError( - "'min_subgraph_size' should be greater than 0, " f"but got {v}" + "'INFER.min_subgraph_size' should be greater than 0, " + f"but got {v}" ) return v @field_validator("gpu_mem") def gpu_mem_check(cls, v): if v <= 0: - raise ValueError("'gpu_mem' should be greater than 0, " f"but got {v}") + raise ValueError( + "'INFER.gpu_mem' should be greater than 0, " f"but got {v}" + ) return v @field_validator("gpu_id") def gpu_id_check(cls, v): if v < 0: raise ValueError( - "'gpu_id' should be greater than or equal to 0, " f"but got {v}" + "'INFER.gpu_id' should be greater than or equal to 0, " + f"but got {v}" ) return v @@ -240,7 +255,7 @@ def gpu_id_check(cls, v): def max_batch_size_check(cls, v): if v <= 0: raise ValueError( - "'max_batch_size' should be greater than 0, " f"but got {v}" + "'INFER.max_batch_size' should be greater than 0, " f"but got {v}" ) return v @@ -248,16 +263,16 @@ def max_batch_size_check(cls, v): def num_cpu_threads_check(cls, v): if v < 0: raise ValueError( - "'num_cpu_threads' should be greater than or equal to 0, " + "'INFER.num_cpu_threads' should be greater than or equal to 0, " f"but got {v}" ) return v @field_validator("batch_size") def batch_size_check(cls, v): - if v <= 0: + if isinstance(v, int) and v <= 0: raise ValueError( - "'batch_size' should be greater than 0, " f"but got {v}" + f"'INFER.batch_size' should be greater than 0 or None, but got {v}" ) return v @@ -326,7 +341,8 @@ def use_wandb_check(cls, v, info: ValidationInfo): - TRAIN/swa: swa_default <-- 'swa_default' used here - EVAL: eval_default <-- 'eval_default' used here - INFER: infer_default <-- 'infer_default' used here - - _self_ + - _self_ <-- config defined in current yaml + mode: train seed: 42 ... @@ -384,6 +400,7 @@ def use_wandb_check(cls, v, info: ValidationInfo): "EVAL.pretrained_model_path", "EVAL.eval_with_no_grad", "EVAL.compute_metric_by_batch", + "EVAL.batch_size", "INFER.pretrained_model_path", "INFER.export_path", "INFER.pdmodel_path",