Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Surrogate pinns #338

Merged
merged 35 commits into from
Jun 7, 2023
Merged

Conversation

wangguan1995
Copy link
Contributor

PR types

Others

PR changes

APIs

Describe

Add example of paper [Surrogate modeling for fluid flows based on physics-constrained deep learning without simulation data]

@paddle-bot
Copy link

paddle-bot bot commented May 24, 2023

Thanks for your contribution!

examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
Comment on lines 113 to 135
class MiniBatchDataset(io.Dataset):
def __init__(self, input, label, weight):
super().__init__()
self.input = input
self.label = label
self.num_samples = self.check_input(input)

def check_input(self, input):
len_input = set()
for _, value in input.items():
len_input.add(len(value))
if len(len_input) is not 1:
raise AttributeError("Input dimension mismatch")
else:
return list(len_input)[0]

def __getitem__(self, idx):
input_item = {key: value[idx] for key, value in self.input.items()}
label_item = {key: value[idx] for key, value in self.label.items()}
return (input_item, label_item)

def __len__(self):
return self.num_samples
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weight_dict返回空字典即可,不需要写一个 MiniBatchDataset

@@ -43,7 +43,7 @@ def log_train_info(trainer, batch_size, epoch_id, iter_id):

metric_msg = ", ".join(
[
f"{key}: {trainer.train_output_info[key].avg:.5f}"
f"{key}: {trainer.train_output_info[key].val:.9f}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

调试代码改回去

logger.info(f"[Train][Epoch {epoch_id}/{self.epochs}][Avg] {metric_msg}")
# logger.info(f"[Train][Epoch {epoch_id}/{self.epochs}][Avg] {metric_msg}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

调试代码改回去

@@ -72,7 +72,6 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
if solver.update_freq > 1:
total_loss = total_loss / solver.update_freq
loss_dict["loss"] = float(total_loss)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要删除空行

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除调试代码

ppsci/data/__init__.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Show resolved Hide resolved
ppsci/arch/activation.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
examples/aneurysm/aneurysm_flow.py Show resolved Hide resolved
import matplotlib.pyplot as plt
import numpy as np
import paddle
from paddle.fluid import core
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除

bbox_inches="tight",
)
plt.close("all")
print(f"epochs = {EPOCHS}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用 ppsci.utils.logger 代替 print

from ppsci.utils import logger
logger.info(...)

examples/pipe/poiseuille_flow.py Show resolved Hide resolved

# set model
model_u = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("u"), 3, 50, "swish")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除多余空行

Comment on lines 48 to 49
"swish": Swish(),
"silu": silu,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要调换原来 silu 的位置

ppsci/equation/pde/navier_stokes.py Show resolved Hide resolved
Comment on lines 35 to 36
PLOT_DIR = osp.join(OUTPUT_DIR, "visu")
os.makedirs(PLOT_DIR, exist_ok=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量定义尽量靠近使用位置

X_IN = 0
X_OUT = X_IN + L
R_INLET = 0.05
unique_x = np.linspace(X_IN, X_OUT, 100)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unique_x 定以后没有用到

examples/aneurysm/aneurysm_flow.py Outdated Show resolved Hide resolved
for x0 in x_inital:
index = np.where(x[:, 0] == x0)[0]
# y is linear to scale, so we place linespace to get 1000 x, it coressponds to vessels
y[index] = np.linspace(-max(y_up[index]), max(y_up[index]), len(index)).reshape(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所有np.linspace显式指定dtype

examples/aneurysm/aneurysm_flow.py Show resolved Hide resolved
examples/pipe/poiseuille_flow.py Outdated Show resolved Hide resolved
lw=2.0,
alpha=1.0,
)
nu_current = float("{0:.5f}".format(data_1d_nu[nu_index[idxP]]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.format全部改成f-string

plt.ylabel(r"PDF", fontsize=fontsize)
ax1.tick_params(axis="x", labelsize=fontsize)
ax1.tick_params(axis="y", labelsize=fontsize)
plt.savefig("pipe_unformUQ.png", bbox_inches="tight")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

跟第一个案例一样,新建一个 PLOT_DIR,所有结果保存到这个文件里去,不要直接保存

def __init__(self, beta: float = 1.0):
super().__init__()
self.beta = self.create_parameter(
shape=[1],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape=[]

super().__init__()
self.beta = self.create_parameter(
shape=[1],
dtype=paddle.get_default_dtype(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除dtype=...

shape=[],
default_initializer=paddle.nn.initializer.Constant(beta),
)
self.add_parameter("beta", self.beta)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要手动add_parameter,删掉

examples/pipe/poiseuille_flow.py Show resolved Hide resolved
examples/pipe/poiseuille_flow.py Show resolved Hide resolved
bbox_inches="tight",
)
plt.close("all")
logger.info(f"epochs = {EPOCHS}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这句话删掉,没必要在末尾打印 EPOCHS

Comment on lines 223 to 229
def single_test(x, y, scale, solver):
xt = paddle.to_tensor(x)
yt = paddle.to_tensor(y)
scalet = paddle.full_like(xt, scale)
input_dict = {"x": xt, "y": yt, "scale": scalet}
output_dict = solver.predict(input_dict, batch_size=100)
return output_dict
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. single_test -> model_predict
  2. numpy的转换代码放到函数里面,return {k: v.numpy() for k, v in output_dict.items()}
  3. 给 model_predict 加上 type hint
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from ppsci import solver
...
...

def model_predict(x: np.ndarray, y: np.ndarray, scale: np.ndarray, solver: "solver.Solver"):
    ...

# See the License for the specific language governing permissions and
# limitations under the License.


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除一个空行

Comment on lines 267 to 268
for i in range(num_test):
u_max_a[i] = (R**2) * dP / (2 * L * data_1d_nu_distribution[i] * RHO)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成 u_max_a = (R**2) * dP / (2 * L * data_1d_nu_distribution * RHO)

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit 839743b into PaddlePaddle:develop Jun 7, 2023
@HydrogenSulfate HydrogenSulfate deleted the surrogate_pinns branch June 7, 2023 11:59
@HydrogenSulfate HydrogenSulfate restored the surrogate_pinns branch September 20, 2023 03:36
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* fix random and create demo

* Reproduce Fig. 2.(a) successfully

* reproduce fig1 (a) and (b) successfully

* reproduce aneurysm case in paper 3.2

* Aneurysm converged

* clean pipe flow code

* add main

* clean code

* using xvaier instead tf net initial params

* muli opt:w

* need shuffle

* refine by review

* add wss formula

* review fix

* delete useless comment

* reformat pipe flow

* merge new NS eqn which thinks of nu as variable

* reformat transform

* recover mlp code

* delete blank line

* fix by reviews

* fix epochs

* delete useless lines

* bug fix

* forget to uncomment train

* fix multi-learnable bug

* fix by reviews

* fix by reviews

* clean activation

* add reference

* fix by reviews
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants