-
Notifications
You must be signed in to change notification settings - Fork 184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【SCU】【PPSCI Export&Infer No.10】 cfdgcn #1037
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,4 +53,4 @@ repos: | |
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ | ||
|
||
exclude: | | ||
^jointContribution/ | ||
^jointContribution/ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 末尾换行不要删除 |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -2,6 +2,20 @@ | |||
|
||||
<a href="https://aistudio.baidu.com/projectdetail/7127446" class="md-button md-button--primary" style>AI Studio快速体验</a> | ||||
|
||||
=== "模型导出命令" | ||||
|
||||
``` sh | ||||
python cfdgcn.py mode=export | ||||
``` | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
|
||||
=== "模型推理命令" | ||||
|
||||
``` sh | ||||
python cfdgcn.py mode=infer | ||||
``` | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
|
||||
=== "模型训练命令" | ||||
|
||||
``` sh | ||||
|
@@ -23,7 +37,7 @@ | |||
# generalization experiments | ||||
mpirun -np $((BATCH_SIZE+1)) python cfdgcn.py \ | ||||
TRAIN.batch_size=$((BATCH_SIZE)) \ | ||||
TRAIN_DATA_DIR="./data/NACA0012_machsplit_noshock/outputs_train" \ | ||||
TRAIN_DATA_DIR="./data/NACA0012_machsplit_noshock/outputs_train" \ | ||||
TRAIN_MESH_GRAPH_PATH="./data/NACA0012_machsplit_noshock/mesh_fine. su2" \ | ||||
EVAL_DATA_DIR="./data/NACA0012_machsplit_noshock/outputs_test" \ | ||||
EVAL_MESH_GRAPH_PATH="./data/NACA0012_machsplit_noshock/mesh_fine.su2" \ | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -157,6 +157,90 @@ def train(cfg: DictConfig): | |
) | ||
|
||
|
||
def export(cfg: DictConfig): | ||
# set model | ||
model = ppsci.arch.CFDGCN( | ||
**cfg.MODEL, | ||
process_sim=process_sim, | ||
fine_marker_dict=fine_marker_dict, | ||
su2_module=su2paddle.SU2Module, | ||
) | ||
|
||
solver = ppsci.solver.Solver( | ||
model, pretrained_model_path=cfg.EXPORT.pretrained_model_path | ||
) | ||
|
||
# export model | ||
from paddle.static import InputSpec | ||
|
||
input_spec = [ | ||
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, | ||
] | ||
Comment on lines
+176
to
+178
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cfdgcn的输入是[None, 1]的形状吗?感觉这里不太对 |
||
|
||
solver.export(input_spec, cfg.INFER.export_path) | ||
|
||
|
||
def inference(cfg: DictConfig): | ||
|
||
# 初始化预测器 | ||
Comment on lines
+184
to
+185
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 换行保留,但是中文注释可以删除 |
||
from deploy.python_infer import pinn_predictor | ||
|
||
predictor = pinn_predictor.PINNPredictor(cfg) | ||
|
||
# 设置 dataloader 配置 | ||
infer_dataloader_cfg = { | ||
"dataset": { | ||
"name": "MeshAirfoilDataset", | ||
"input_keys": ("input",), | ||
"label_keys": ("label",), | ||
"data_dir": cfg.INFER_DATA_DIR, | ||
"mesh_graph_path": cfg.INFER_MESH_GRAPH_PATH, | ||
"transpose_edges": True, | ||
}, | ||
"batch_size": cfg.INFER.batch_size, | ||
"sampler": { | ||
"name": "BatchSampler", | ||
"drop_last": False, | ||
"shuffle": False, | ||
}, | ||
} | ||
|
||
# 初始化数据集和 dataloader | ||
infer_dataset = ppsci.data.MeshAirfoilDataset( | ||
input_keys=infer_dataloader_cfg["dataset"]["input_keys"], | ||
label_keys=infer_dataloader_cfg["dataset"]["label_keys"], | ||
data_dir=infer_dataloader_cfg["dataset"]["data_dir"], | ||
mesh_graph_path=infer_dataloader_cfg["dataset"]["mesh_graph_path"], | ||
transpose_edges=infer_dataloader_cfg["dataset"]["transpose_edges"], | ||
) | ||
infer_dataloader = ppsci.dataloader.DataLoader( | ||
infer_dataset, | ||
batch_size=infer_dataloader_cfg["batch_size"], | ||
shuffle=infer_dataloader_cfg["sampler"]["shuffle"], | ||
drop_last=infer_dataloader_cfg["sampler"]["drop_last"], | ||
num_workers=infer_dataloader_cfg.get("num_workers", 1), | ||
) | ||
|
||
# 进行推理并可视化结果 | ||
with predictor.no_grad_context_manager(True): | ||
for index, (input_dict, label_dict, _) in enumerate(infer_dataloader): | ||
# 获取真实值 | ||
truefield = label_dict["label"].y | ||
# 模型预测 | ||
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) | ||
prefield = output_dict["pred"] | ||
|
||
# 可视化结果 | ||
utils.log_images( | ||
input_dict["input"].pos, | ||
prefield, | ||
truefield, | ||
infer_dataset.elems_list, | ||
index, | ||
"cylinder", | ||
) | ||
|
||
|
||
def evaluate(cfg: DictConfig): | ||
# set dataloader config | ||
train_dataloader_cfg = { | ||
|
@@ -257,8 +341,14 @@ def main(cfg: DictConfig): | |
train(cfg) | ||
elif cfg.mode == "eval": | ||
evaluate(cfg) | ||
elif cfg.mode == "export": | ||
export(cfg) | ||
elif cfg.mode == "infer": | ||
inference(cfg) | ||
else: | ||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||
raise ValueError( | ||
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个pre-commit应该不用改吧