Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hackathon 5th No.58 A physics-informed deep neural network for surrogate modeling in classical elasto-plasticity #558

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/zh/api/loss.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- L1Loss
- L2Loss
- L2RelLoss
- MAELoss
- MSELoss
- MSELossWithL2Decay
- IntegralLoss
Expand Down
141 changes: 141 additions & 0 deletions docs/zh/examples/epnn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# EPNN

=== "模型训练命令"

``` sh
# linux
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstate-16-plas.dat -O datasets/epnn/dstate-16-plas.dat
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstress-16-plas.dat -O datasets/epnn/dstress-16-plas.dat
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstate-16-plas.dat --output datasets/epnn/dstate-16-plas.dat
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstress-16-plas.dat --output datasets/epnn/dstress-16-plas.dat
python epnn.py
```

=== "模型评估命令"

``` sh
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstate-16-plas.dat -O datasets/epnn/dstate-16-plas.dat
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstress-16-plas.dat -O datasets/epnn/dstress-16-plas.dat
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstate-16-plas.dat --output datasets/epnn/dstate-16-plas.dat
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/epnn/dstress-16-plas.dat --output datasets/epnn/dstress-16-plas.dat
python epnn.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/epnn/epnn_pretrained.pdparams
```

## 1. 背景简介

在这项工作中,我们提出了一种能够高效逼近经典弹塑性本构关系的深度神经网络架构。该网络富含经典弹塑性的关键物理方面,包括应变添加剂分解为弹性和塑性部分,以及非线性增量弹性。这导致了一个名为Elasto-Plastic Neural Network (EPNN)的Physics-Informed Neural Network (PINN)代理模型。详细的分析表明,将这些物理嵌入神经网络的架构中,可以更有效地训练网络,同时使用更少的数据进行训练,同时增强对训练数据外加载制度的推断能力。EPNN的架构是模型和材料无关的,即它可以适应各种弹塑性材料类型,包括地质材料和金属;并且实验数据可以直接用于训练网络。为了证明所提出架构的稳健性,我们将其一般框架应用于砂土的弹塑性行为。我们使用基于相对先进的基于流变性的颗粒材料本构模型的材料点模拟生成的合成数据来训练神经网络。EPNN在预测不同初始密度砂土的未观测应变控制加载路径方面优于常规神经网络架构。

## 2. 问题定义

在神经网络中,信息通过由连接的神经元流动。神经网络中每个链接的“强度”是由一个可变的权重决定的:

$$
z_l^{\mathrm{i}}=W_{k l}^{\mathrm{i}-1, \mathrm{i}} a_k^{\mathrm{i}-1}+b^{\mathrm{i}-1}, \quad k=1: N^{\mathrm{i}-1} \quad \text { or } \quad \mathbf{z}^{\mathrm{i}}=\mathbf{a}^{\mathrm{i}-1} \mathbf{W}^{\mathrm{i}-1, \mathrm{i}}+b^{\mathrm{i}-1} \mathbf{I}
$$

其中b是偏置项;N为不同层中神经元数量;I指的是所有元素都为1的单位向量。

## 3. 问题求解

接下来开始讲解如何将问题一步一步地转化为 PaddleScience 代码,用深度学习的方法求解该问题。
为了快速理解 PaddleScience,接下来仅对模型构建、方程构建、计算域构建等关键步骤进行阐述,而其余细节请参考 [API文档](../api/arch.md)。

### 3.1 模型构建

在 EPNN 问题中,建立网络,用 PaddleScience 代码表示如下

``` py linenums="341"
--8<--
examples/epnn/functions.py:341:361
--8<--
```

Epnn 参数 input_keys 是输入字段名,output_keys 是输出字段名,node_sizes 是节点大小列表,activations 是激活函数字符串列表,drop_p 是 nn.Dropout 中的 p 参数。

### 3.2 数据构建

本案例涉及读取数据构建,如下所示

``` py linenums="36"
--8<--
examples/epnn/epnn.py:36:41
--8<--
```

### 3.3 约束构建

设置训练数据集和损失计算函数,返回字段,代码如下所示:

``` py linenums="83"
--8<--
examples/epnn/epnn.py:83:103
--8<--
```

### 3.4 评估器构建

设置评估数据集和损失计算函数,返回字段,代码如下所示:

``` py linenums="106"
--8<--
examples/epnn/epnn.py:106:128
--8<--
```

### 3.5 超参数设定

接下来我们需要指定训练轮数,此处我们按实验经验,使用 10000 轮训练轮数。

``` yaml linenums="39"
--8<--
examples/epnn/conf/epnn.yaml:39:39
--8<--
```

### 3.6 优化器构建

训练过程会调用优化器来更新模型参数,此处选择 `Adam` 优化器并设定 `learning_rate`。

``` py linenums="366"
--8<--
examples/epnn/functions.py:366:412
--8<--
```

### 3.7 模型训练与评估

完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`。

``` py linenums="131"
--8<--
examples/epnn/epnn.py:131:144
--8<--
```

最后启动训练即可:

``` py linenums="147"
--8<--
examples/epnn/epnn.py:147:147
--8<--
```

## 4. 完整代码

``` py linenums="1" title="epnn.py"
--8<--
examples/epnn/epnn.py
--8<--
```

## 5. 结果展示

EPNN 案例针对 epoch=10000 的参数配置进行了实验,结果返回Loss为 0.00471。

## 6. 参考资料

- [A physics-informed deep neural network for surrogate
modeling in classical elasto-plasticity](https://arxiv.org/abs/2204.12088)
- <https://github.com/meghbali/ANNElastoplasticity>
45 changes: 45 additions & 0 deletions examples/epnn/conf/epnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_epnn/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode} # name of logfile
chdir: false # keep current working direcotry unchaned
config:
override_dirname:
exclude_keys:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- mode
- output_dir
- log_freq
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}

# set working condition
DATASET_STATE: datasets/epnn/dstate-16-plas.dat
DATASET_STRESS: datasets/epnn/dstress-16-plas.dat
NTRAIN_SIZE: 40

# model settings
MODEL:
ihlayers: 3
ineurons: 60

# training settings
TRAIN:
epochs: 10000
iters_per_epoch: 1
save_freq: 50
eval_during_train: true
eval_with_no_grad: true
pretrained_model_path: null
checkpoint_path: null
165 changes: 165 additions & 0 deletions examples/epnn/epnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Reference: https://github.com/meghbali/ANNElastoplasticity
"""

from os import path as osp

import functions
import hydra
import paddle
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger


def train(cfg: DictConfig):
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)
# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

(
input_dict_train,
label_dict_train,
input_dict_val,
label_dict_val,
) = functions.get_data(cfg.DATASET_STATE, cfg.DATASET_STRESS, cfg.NTRAIN_SIZE)
model_list = functions.get_model_list(
cfg.MODEL.ihlayers,
cfg.MODEL.ineurons,
input_dict_train["state_x"][0].shape[1],
input_dict_train["state_y"][0].shape[1],
input_dict_train["stress_x"][0].shape[1],
)
optimizer_list = functions.get_optimizer_list(
model_list, cfg.TRAIN.epochs, cfg.TRAIN.iters_per_epoch
)
model_state1, model_state2, model_stress = model_list
model_list_obj = ppsci.arch.ModelList(model_list)

def transform_f(input, model, out_key):
input11 = model(input)[out_key]
co63oc marked this conversation as resolved.
Show resolved Hide resolved
input11 = input11.detach().clone()
input_transformed = {}
for key in input:
input_transformed[key] = paddle.squeeze(input[key], axis=0)
input1m = paddle.concat(
x=(
input11,
paddle.index_select(
input_transformed["state_x"],
paddle.to_tensor([0, 1, 2, 3, 7, 8, 9, 10, 11, 12]),
axis=1,
),
),
axis=1,
)
input_transformed["state_x_f"] = input1m
return input_transformed

def transform_f_stress(_in):
return transform_f(_in, model_state1, "u")

model_state1.register_input_transform(functions.transform_in)
model_state2.register_input_transform(functions.transform_in)
model_stress.register_input_transform(transform_f_stress)
model_stress.register_output_transform(functions.transform_out)

sup_constraint_pde = ppsci.constraint.SupervisedConstraint(
{
"dataset": {
"name": "NamedArrayDataset",
"input": input_dict_train,
"label": label_dict_train,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码运行报错,原因是NamedArrayDataset的输入类型不对
image
ppsci.utils.reader.load_dat_file()读进来的dict中value还是np.darray类型,在传到这里变成了List[Tensor,...]
functions中对数据的处理是否不需要一定转成tensor类型?
list转array可以用np.array(list)

},
"num_workers": 0,
},
ppsci.loss.FunctionalLoss(functions.train_loss_func),
{
"state_x": lambda out: out["state_x"],
"state_y": lambda out: out["state_y"],
"stress_x": lambda out: out["stress_x"],
"stress_y": lambda out: out["stress_y"],
"out_state1": lambda out: out["u"],
"out_state2": lambda out: out["v"],
"out_stress": lambda out: out["w"],
},
name="sup_train",
)
constraint_pde = {sup_constraint_pde.name: sup_constraint_pde}

sup_validator_pde = ppsci.validate.SupervisedValidator(
{
"dataset": {
"name": "NamedArrayDataset",
"input": input_dict_val,
"label": label_dict_val,
co63oc marked this conversation as resolved.
Show resolved Hide resolved
},
"num_workers": 0,
},
ppsci.loss.FunctionalLoss(functions.eval_loss_func),
{
"state_x": lambda out: out["state_x"],
"state_y": lambda out: out["state_y"],
"stress_x": lambda out: out["stress_x"],
"stress_y": lambda out: out["stress_y"],
"out_state1": lambda out: out["u"],
"out_state2": lambda out: out["v"],
"out_stress": lambda out: out["w"],
},
metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)},
name="sup_valid",
)
validator_pde = {sup_validator_pde.name: sup_validator_pde}

functions.OUTPUT_DIR = cfg.output_dir
# initialize solver
solver = ppsci.solver.Solver(
model_list_obj,
constraint_pde,
cfg.output_dir,
optimizer_list,
None,
cfg.TRAIN.epochs,
cfg.TRAIN.iters_per_epoch,
save_freq=cfg.TRAIN.save_freq,
eval_during_train=cfg.TRAIN.eval_during_train,
validator=validator_pde,
eval_with_no_grad=cfg.TRAIN.eval_with_no_grad,
)

# train model
solver.train()


def evaluate(cfg: DictConfig):
print("Not supported.")


@hydra.main(version_base=None, config_path="./conf", config_name="epnn.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
main()
Loading