11# Allen-Cahn
22
3- <!-- < a href="TODO " class="md-button md-button--primary" style>AI Studio快速体验</a> -- >
3+ <a href =" https://aistudio.baidu.com/projectdetail/7927786 " class =" md-button md-button--primary " style >AI Studio快速体验</a >
44
55=== "模型训练命令"
66
77 ``` sh
8- python allen_cahn_default.py
8+ # linux
9+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
10+ # windows
11+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
12+ python allen_cahn_piratenet.py
913 ```
1014
1115=== "模型评估命令"
1216
1317 ``` sh
14- python allen_cahn_default.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/allen_cahn/allen_cahn_default_pretrained.pdparams
18+ # linux
19+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
20+ # windows
21+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
22+ python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
1523 ```
1624
1725=== "模型导出命令"
1826
1927 ``` sh
20- python allen_cahn_default .py mode=export
28+ python allen_cahn_piratenet .py mode=export
2129 ```
2230
2331=== "模型推理命令"
2432
2533 ``` sh
26- python allen_cahn_default.py mode=infer
34+ # linux
35+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
36+ # windows
37+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
38+ python allen_cahn_piratenet.py mode=infer
2739 ```
2840
2941| 预训练模型 | 指标 |
3042| :--| :--|
31- | [ allen_cahn_default_pretrained .pdparams] ( TODO ) | TODO |
43+ | [ allen_cahn_piratenet_pretrained .pdparams] ( https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams ) | L2Rel.u: 8.32403e-06 |
3244
3345## 1. 背景简介
3446
7284### 3.1 模型构建
7385
7486在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
75- ,在这里使用比较简单的 MLP(Multilayer Perceptron, 多层感知机) 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
87+ ,在这里使用 PirateNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
7688
7789$$
7890u = f(t, x)
7991$$
8092
81- 上式中 $f$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
93+ 上式中 $f$ 即为 PirateNet 模型本身,用 PaddleScience 代码表示如下
8294
8395``` py linenums="63"
8496-- 8 < --
85- examples/ allen_cahn/ allen_cahn_default .py:63 :64
97+ examples/ allen_cahn/ allen_cahn_piratenet .py:63 :64
8698-- 8 < --
8799```
88100
89101为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 ` ("t", "x") ` ,输出变量名是 ` ("u") ` ,这些命名与后续代码保持一致。
90102
91- 接着通过指定 MLP 的层数、神经元个数,就实例化出了一个拥有 4 层隐藏神经元,每层神经元数为 256 的神经网络模型 ` model ` ,使用 ` tanh ` 作为激活函数。
103+ 接着通过指定 PirateNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 ` model ` , 并且使用 ` tanh ` 作为激活函数。
92104
93- ``` yaml linenums="35 "
105+ ``` yaml linenums="34 "
94106--8<--
95- examples/allen_cahn/conf/allen_cahn_default .yaml:35:41
107+ examples/allen_cahn/conf/allen_cahn_piratenet .yaml:34:40
96108--8<--
97109```
98110
@@ -102,7 +114,7 @@ Allen-Cahn 微分方程可以用如下代码表示:
102114
103115``` py linenums="66"
104116-- 8 < --
105- examples/ allen_cahn/ allen_cahn_default .py:66 :67
117+ examples/ allen_cahn/ allen_cahn_piratenet .py:66 :67
106118-- 8 < --
107119```
108120
@@ -112,7 +124,7 @@ examples/allen_cahn/allen_cahn_default.py:66:67
112124
113125``` py linenums="69"
114126-- 8 < --
115- examples/ allen_cahn/ allen_cahn_default .py:69 :81
127+ examples/ allen_cahn/ allen_cahn_piratenet .py:69 :81
116128-- 8 < --
117129```
118130
@@ -124,7 +136,7 @@ examples/allen_cahn/allen_cahn_default.py:69:81
124136
125137``` py linenums="94"
126138-- 8 < --
127- examples/ allen_cahn/ allen_cahn_default .py:94 :110
139+ examples/ allen_cahn/ allen_cahn_piratenet .py:94 :110
128140-- 8 < --
129141```
130142
@@ -139,11 +151,11 @@ examples/allen_cahn/allen_cahn_default.py:94:110
139151#### 3.4.2 周期边界约束
140152
141153此处我们采用 hard-constraint 的方式,在神经网络模型中,对输入数据使用cos、sin等周期函数进行周期化,从而让$u_ {\theta}$在数学上直接满足方程的周期性质。
142- 根据方程可得函数$u(t, x)$在$x$轴上的周期为2 ,因此将该周期设置到模型配置里即可。
154+ 根据方程可得函数$u(t, x)$在$x$轴上的周期为 2 ,因此将该周期设置到模型配置里即可。
143155
144- ``` yaml linenums="35 "
156+ ``` yaml linenums="41 "
145157--8<--
146- examples/allen_cahn/conf/allen_cahn_default .yaml:35:43
158+ examples/allen_cahn/conf/allen_cahn_piratenet .yaml:41:42
147159--8<--
148160```
149161
@@ -153,25 +165,25 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
153165
154166``` py linenums="112"
155167-- 8 < --
156- examples/ allen_cahn/ allen_cahn_default .py:112 :125
168+ examples/ allen_cahn/ allen_cahn_piratenet .py:112 :125
157169-- 8 < --
158170```
159171
160172在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。
161173
162174``` py linenums="126"
163175-- 8 < --
164- examples/ allen_cahn/ allen_cahn_default .py:126 :130
176+ examples/ allen_cahn/ allen_cahn_piratenet .py:126 :130
165177-- 8 < --
166178```
167179
168180### 3.5 超参数设定
169181
170- 接下来需要指定训练轮数和学习率,此处按实验经验,使用 200 轮训练轮数,0.001 的初始学习率。
182+ 接下来需要指定训练轮数和学习率,此处按实验经验,使用 300 轮训练轮数,0.001 的初始学习率。
171183
172- ``` yaml linenums="51 "
184+ ``` yaml linenums="50 "
173185--8<--
174- examples/allen_cahn/conf/allen_cahn_default .yaml:51:73
186+ examples/allen_cahn/conf/allen_cahn_piratenet .yaml:50:63
175187--8<--
176188```
177189
@@ -181,7 +193,7 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
181193
182194``` py linenums="132"
183195-- 8 < --
184- examples/ allen_cahn/ allen_cahn_default .py:132 :136
196+ examples/ allen_cahn/ allen_cahn_piratenet .py:132 :136
185197-- 8 < --
186198```
187199
@@ -191,7 +203,7 @@ examples/allen_cahn/allen_cahn_default.py:132:136
191203
192204``` py linenums="138"
193205-- 8 < --
194- examples/ allen_cahn/ allen_cahn_default .py:138 :156
206+ examples/ allen_cahn/ allen_cahn_piratenet .py:138 :156
195207-- 8 < --
196208```
197209
@@ -201,15 +213,15 @@ examples/allen_cahn/allen_cahn_default.py:138:156
201213
202214``` py linenums="158"
203215-- 8 < --
204- examples/ allen_cahn/ allen_cahn_default .py:158 :194
216+ examples/ allen_cahn/ allen_cahn_piratenet .py:158 :184
205217-- 8 < --
206218```
207219
208220## 4. 完整代码
209221
210- ``` py linenums="1" title="allen_cahn_default .py"
222+ ``` py linenums="1" title="allen_cahn_piratenet .py"
211223-- 8 < --
212- examples/ allen_cahn/ allen_cahn_default .py
224+ examples/ allen_cahn/ allen_cahn_piratenet .py
213225-- 8 < --
214226```
215227
@@ -218,12 +230,13 @@ examples/allen_cahn/allen_cahn_default.py
218230在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。
219231
220232<figure markdown >
221- ![ allen_cahn_default .jpg] ( https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_default .png ) { loading=lazy }
233+ ![ allen_cahn_piratenet .jpg] ( https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piratenet_ac .png ) { loading=lazy }
222234 <figcaption > 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption >
223235</figure >
224236
225237可以看到对于函数$u(t, x)$,模型的预测结果和解析解的结果基本一致。
226238
227239## 6. 参考资料
228240
241+ - [ PIRATENETS: PHYSICS-INFORMED DEEP LEARNING WITHRESIDUAL ADAPTIVE NETWORKS] ( https://arxiv.org/pdf/2402.00326.pdf )
229242- [ Allen-Cahn equation] ( https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/examples/allen_cahn/README.md )
0 commit comments