-
Notifications
You must be signed in to change notification settings - Fork 165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add AMGNet example #549
[Example] Add AMGNet example #549
Conversation
Thanks for your contribution! |
7b71329
to
38659ca
Compare
b8aedc3
to
1254dc6
Compare
1254dc6
to
66954a6
Compare
c4bc77b
to
031f210
Compare
examples/amgnet/utils.py
Outdated
matplotlib.use("Agg") | ||
|
||
|
||
def getcorsenode(latent_graph): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议单词之间使用下划线隔开
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
plt.savefig(out_file) | ||
plt.close() | ||
|
||
if show: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
检查一下该if内是否还需要抛出异常
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
ppsci/arch/amgnet.py
Outdated
return output | ||
|
||
|
||
def getcorsenode(latent_graph: "pgl.Graph") -> paddle.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
:param segment_ids: The segment indices tensor. | ||
:param num_segments: The number of segments. | ||
:return: A tensor of same data type as the data argument. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
规范docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
Then the result is add to output latent graph of encoder and the modified latent graph will be feed into original processor | ||
|
||
Option: choose whether to normalize the high rank node connection | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
规范docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
ppsci/data/__init__.py
Outdated
places=device.get_device(), | ||
batch_sampler=sampler, | ||
collate_fn=collate_fn, | ||
num_workers=cfg.get("num_workers", 0), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
122行、113行两处的num_workers的默认值是否可以保持一致?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为1
@@ -46,6 +48,8 @@ | |||
"LorenzDataset", | |||
"RosslerDataset", | |||
"VtuDataset", | |||
"MeshAirfoilDataset", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
保持与import顺序一致
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
@@ -46,6 +48,8 @@ | |||
"LorenzDataset", | |||
"RosslerDataset", | |||
"VtuDataset", | |||
"MeshAirfoilDataset", | |||
"MeshCylinderDataset", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
for _ in range(num_elems) | ||
] | ||
marker_dict[marker_tag] = marker_elems | ||
if line.startswith("NELEM"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
与上面的两个if是互斥关系,可否使用elif代替
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
c672678
to
c3e560a
Compare
3797afe
to
451491a
Compare
451491a
to
3a6e985
Compare
6a41855
to
11c803d
Compare
544e73c
to
98f5a8a
Compare
98f5a8a
to
876ef4d
Compare
docs/zh/examples/amgnet.md
Outdated
|
||
- AMGNET 通过 RS 算法(Olson and Schroder, 2018)进行了图的粗化,仅使用少量节点即可进行预测,进一步提高了预测速度。 | ||
|
||
下图为该方法的网络结构图。该模型的基本原理就是将网格结构转化为图结构,然后通过网格中节点的物理信息,位置信息以及节点类型对图中的节点和边进行编码。接着对得到的图神经网络使用基于代数多重网格算法(RS)的粗化层进行粗化,将所有节点分类为粗节点集和细节点集,其中粗节点集是细节点集的子集。粗图的节点集合就是粗节点集,于是完成了图的粗化,缩小了图的规模。粗化完成后通过设计的图神经网络信息传递块(GN)来总结和提取图的特征。之后图恢复层采用反向操作,使用空间插值法(Qi et al.,2017)对图进行上采样。例如要对节点 $i$ 插值,则在粗图中找到距离节点 $i$ 最近的 $k$ 个节点,然后通过公式计算得到节点 $i$ 的特征。最后,通过解码器得到每个节点的速度与压力信息。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"然后通过网格中节点的物理信息,位置信息以及节点类型对图中的节点和边进行编码" --> "然后通过网格中节点的物理信息、位置信息以及节点类型对图中的节点和边进行编码"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改,逗号改顿号
docs/zh/examples/amgnet.md
Outdated
|
||
### 3.6 评估器构建 | ||
|
||
在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](#34) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上文中约束构建是3.3 约束构建
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
constraint, | ||
cfg.output_dir, | ||
optimizer, | ||
None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以去掉吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个None参数属于位置参数,不能删除,删了的话代码行数会变动
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* update WIP code * (WIP)update AMGNet code * try import pgl to avoid importerror * try import pyamg to avoid importerror * add airfoil_dataset.py * add type checking for amgnet * try import pgl to avoid importerror * refine Timer * replace pgl.Dataset with io.Dataset * update reproded code * replace ImportError with ModuleNotFoundError * refine amgnet.py * refine amgnet_airfoil.py and amgnet_cylinder.py * refine utils.py * refine collate_fn * fix bug in eval.py * refine codes * refine codes * modify atol from 1e-8 to 1e-7 of UT test_navierstokes * refine code * add AMGNet document * fix * fix * avoid tensor converion in dataset, and move in to collate_fn * update final code * add example for AMGNet * fix doc
PR types
New features
PR changes
Others
Describe
ppsci.utils.misc
中添加Timer
计时器上下文管理器功能,便于对代码块进行计时pgl
的DataLoader
和pgl.Graph
类型的数据MeshAirfoilDataset
数据集其他:
参考案例: https://aistudio.baidu.com/projectdetail/6779372?contributionType=1