Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] Add AMGNet example #549

Merged
merged 29 commits into from
Oct 23, 2023

Conversation

HydrogenSulfate
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate commented Sep 22, 2023

PR types

New features

PR changes

Others

Describe

  1. 添加 AMGNet 代码和文档
  2. ppsci.utils.misc 中添加 Timer 计时器上下文管理器功能,便于对代码块进行计时
  3. train.py/eval.py/collate_fn 代码适配 pglDataLoaderpgl.Graph 类型的数据
  4. 添加 MeshAirfoilDataset 数据集

其他:

  1. 优化多处代码逻辑以适配 AMGNet
  2. 文档首页"快速安装"章节样式调整

参考案例: https://aistudio.baidu.com/projectdetail/6779372?contributionType=1

@paddle-bot
Copy link

paddle-bot bot commented Sep 22, 2023

Thanks for your contribution!

@HydrogenSulfate HydrogenSulfate changed the title [WIP] Add AMGNet example [WIP][Example] Add AMGNet example Sep 26, 2023
@HydrogenSulfate HydrogenSulfate changed the title [WIP][Example] Add AMGNet example [Example] Add AMGNet example Sep 27, 2023
matplotlib.use("Agg")


def getcorsenode(latent_graph):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议单词之间使用下划线隔开

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

plt.savefig(out_file)
plt.close()

if show:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

检查一下该if内是否还需要抛出异常

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

return output


def getcorsenode(latent_graph: "pgl.Graph") -> paddle.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

:param segment_ids: The segment indices tensor.
:param num_segments: The number of segments.
:return: A tensor of same data type as the data argument.
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

规范docstring

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Then the result is add to output latent graph of encoder and the modified latent graph will be feed into original processor

Option: choose whether to normalize the high rank node connection
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

规范docstring

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

places=device.get_device(),
batch_sampler=sampler,
collate_fn=collate_fn,
num_workers=cfg.get("num_workers", 0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

122行、113行两处的num_workers的默认值是否可以保持一致?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改为1

@@ -46,6 +48,8 @@
"LorenzDataset",
"RosslerDataset",
"VtuDataset",
"MeshAirfoilDataset",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

保持与import顺序一致

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -46,6 +48,8 @@
"LorenzDataset",
"RosslerDataset",
"VtuDataset",
"MeshAirfoilDataset",
"MeshCylinderDataset",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

for _ in range(num_elems)
]
marker_dict[marker_tag] = marker_elems
if line.startswith("NELEM"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

与上面的两个if是互斥关系,可否使用elif代替

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


- AMGNET 通过 RS 算法(Olson and Schroder, 2018)进行了图的粗化,仅使用少量节点即可进行预测,进一步提高了预测速度。

下图为该方法的网络结构图。该模型的基本原理就是将网格结构转化为图结构,然后通过网格中节点的物理信息,位置信息以及节点类型对图中的节点和边进行编码。接着对得到的图神经网络使用基于代数多重网格算法(RS)的粗化层进行粗化,将所有节点分类为粗节点集和细节点集,其中粗节点集是细节点集的子集。粗图的节点集合就是粗节点集,于是完成了图的粗化,缩小了图的规模。粗化完成后通过设计的图神经网络信息传递块(GN)来总结和提取图的特征。之后图恢复层采用反向操作,使用空间插值法(Qi et al.,2017)对图进行上采样。例如要对节点 $i$ 插值,则在粗图中找到距离节点 $i$ 最近的 $k$ 个节点,然后通过公式计算得到节点 $i$ 的特征。最后,通过解码器得到每个节点的速度与压力信息。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"然后通过网格中节点的物理信息,位置信息以及节点类型对图中的节点和边进行编码" --> "然后通过网格中节点的物理信息、位置信息以及节点类型对图中的节点和边进行编码"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,逗号改顿号


### 3.6 评估器构建

在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](#34) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上文中约束构建是3.3 约束构建

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

constraint,
cfg.output_dir,
optimizer,
None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以去掉吧

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个None参数属于位置参数,不能删除,删了的话代码行数会变动

Copy link
Collaborator

@zhiminzhang0830 zhiminzhang0830 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiminzhang0830 zhiminzhang0830 merged commit 59ca3a1 into PaddlePaddle:develop Oct 23, 2023
4 checks passed
@HydrogenSulfate HydrogenSulfate deleted the add_AMGNet branch October 26, 2023 13:22
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* update WIP code

* (WIP)update AMGNet code

* try import pgl to avoid importerror

* try import pyamg to avoid importerror

* add airfoil_dataset.py

* add type checking for amgnet

* try import pgl to avoid importerror

* refine Timer

* replace pgl.Dataset with io.Dataset

* update reproded code

* replace ImportError with ModuleNotFoundError

* refine amgnet.py

* refine amgnet_airfoil.py and amgnet_cylinder.py

* refine utils.py

* refine collate_fn

* fix bug in eval.py

* refine codes

* refine codes

* modify atol from 1e-8 to 1e-7 of UT test_navierstokes

* refine code

* add AMGNet document

* fix

* fix

* avoid tensor converion in dataset, and move in to collate_fn

* update final code

* add example for AMGNet

* fix doc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants