From abc6c22a7979b4d30ee817a99767daaf8a007434 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 11 May 2024 12:27:13 +0000 Subject: [PATCH 1/6] fully support detach_keys argument for all PDE --- ppsci/equation/pde/base.py | 59 ++++++++++++++++++++----- ppsci/equation/pde/biharmonic.py | 2 + ppsci/equation/pde/heat_exchanger.py | 2 + ppsci/equation/pde/laplace.py | 2 + ppsci/equation/pde/linear_elasticity.py | 2 + ppsci/equation/pde/navier_stokes.py | 2 + ppsci/equation/pde/nls_m_b.py | 2 + ppsci/equation/pde/normal_dot_vec.py | 2 + ppsci/equation/pde/poisson.py | 2 + ppsci/equation/pde/viv.py | 2 + ppsci/utils/symbolic.py | 5 +++ 11 files changed, 70 insertions(+), 12 deletions(-) diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py index 9ef55712a3..96edbe0d61 100644 --- a/ppsci/equation/pde/base.py +++ b/ppsci/equation/pde/base.py @@ -22,7 +22,7 @@ from typing import Union import paddle -import sympy +import sympy as sp from paddle import nn DETACH_FUNC_NAME = "detach" @@ -33,7 +33,7 @@ class PDE: def __init__(self): super().__init__() - self.equations = {} + self.equations: Dict[str, Union[Callable, sp.Basic]] = {} # for PDE which has learnable parameter(s) self.learnable_parameters = nn.ParameterList() @@ -42,7 +42,7 @@ def __init__(self): @staticmethod def create_symbols( symbol_str: str, - ) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]: + ) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]: """create symbolic variables. Args: @@ -61,11 +61,9 @@ def create_symbols( >>> print(symbols_xyz) (x, y, z) """ - return sympy.symbols(symbol_str) + return sp.symbols(symbol_str) - def create_function( - self, name: str, invars: Tuple[sympy.Symbol, ...] - ) -> sympy.Function: + def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function: """Create named function depending on given invars. Args: @@ -86,14 +84,51 @@ def create_function( >>> print(f) f(x, y, z) """ - expr = sympy.Function(name)(*invars) + expr = sp.Function(name)(*invars) - # wrap `expression(...)` to `detach(expression(...))` - # if name of expression is in given detach_keys - if self.detach_keys and name in self.detach_keys: - expr = sympy.Function(DETACH_FUNC_NAME)(expr) return expr + def _apply_detach(self): + if self.detach_keys is None: + return + + from copy import deepcopy + + from sympy.core.traversal import postorder_traversal + + from ppsci.utils.symbolic import _cvt_to_key + + for name, expr in self.equations.items(): + if not isinstance(expr, sp.Basic): + continue + # only process sympy expression + expr_ = deepcopy(expr) + for item in postorder_traversal(expr): + if _cvt_to_key(item) in self.detach_keys: + # inplace all related sub_expr into detach(sub_expr) + expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item)) + + # remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping + expr_ = expr_.replace( + sp.Function(DETACH_FUNC_NAME)( + sp.Function(DETACH_FUNC_NAME)(item) + ), + sp.Function(DETACH_FUNC_NAME)(item), + ) + + # remove unccessary detach wrapping for the first arg of Derivative + for item_ in list(postorder_traversal(expr_)): + if isinstance(item_, sp.Derivative): + if item_.args[0].name == DETACH_FUNC_NAME: + expr_ = expr_.replace( + item_, + sp.Derivative( + item_.args[0].args[0], *item_.args[1:] + ), + ) + + self.equations[name] = expr_ + def add_equation(self, name: str, equation: Callable): """Add an equation. diff --git a/ppsci/equation/pde/biharmonic.py b/ppsci/equation/pde/biharmonic.py index 1471c34a6c..933888ac60 100644 --- a/ppsci/equation/pde/biharmonic.py +++ b/ppsci/equation/pde/biharmonic.py @@ -70,3 +70,5 @@ def __init__( biharmonic += u.diff(invar_i, 2).diff(invar_j, 2) self.add_equation("biharmonic", biharmonic) + + self._apply_detach() diff --git a/ppsci/equation/pde/heat_exchanger.py b/ppsci/equation/pde/heat_exchanger.py index d9fd93c224..c2e0107ff3 100644 --- a/ppsci/equation/pde/heat_exchanger.py +++ b/ppsci/equation/pde/heat_exchanger.py @@ -90,3 +90,5 @@ def __init__( self.add_equation("heat_boundary", heat_boundary) self.add_equation("cold_boundary", cold_boundary) self.add_equation("wall", wall) + + self._apply_detach() diff --git a/ppsci/equation/pde/laplace.py b/ppsci/equation/pde/laplace.py index 12b2a03ddd..b99d7c8d9a 100644 --- a/ppsci/equation/pde/laplace.py +++ b/ppsci/equation/pde/laplace.py @@ -51,3 +51,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): laplace += u.diff(invar, 2) self.add_equation("laplace", laplace) + + self._apply_detach() diff --git a/ppsci/equation/pde/linear_elasticity.py b/ppsci/equation/pde/linear_elasticity.py index 9120c6d21c..44833f56bf 100644 --- a/ppsci/equation/pde/linear_elasticity.py +++ b/ppsci/equation/pde/linear_elasticity.py @@ -179,3 +179,5 @@ def __init__( self.add_equation("traction_y", traction_y) if self.dim == 3: self.add_equation("traction_z", traction_z) + + self._apply_detach() diff --git a/ppsci/equation/pde/navier_stokes.py b/ppsci/equation/pde/navier_stokes.py index 41cb819bf9..c0d3d193a2 100644 --- a/ppsci/equation/pde/navier_stokes.py +++ b/ppsci/equation/pde/navier_stokes.py @@ -147,3 +147,5 @@ def __init__( self.add_equation("momentum_y", momentum_y) if self.dim == 3: self.add_equation("momentum_z", momentum_z) + + self._apply_detach() diff --git a/ppsci/equation/pde/nls_m_b.py b/ppsci/equation/pde/nls_m_b.py index 97bf60cabb..3db2984268 100644 --- a/ppsci/equation/pde/nls_m_b.py +++ b/ppsci/equation/pde/nls_m_b.py @@ -97,3 +97,5 @@ def __init__( self.add_equation("Maxwell_1", Maxwell_1) self.add_equation("Maxwell_2", Maxwell_2) self.add_equation("Bloch", Bloch) + + self._apply_detach() diff --git a/ppsci/equation/pde/normal_dot_vec.py b/ppsci/equation/pde/normal_dot_vec.py index de97a140fb..a6f3942eeb 100644 --- a/ppsci/equation/pde/normal_dot_vec.py +++ b/ppsci/equation/pde/normal_dot_vec.py @@ -55,3 +55,5 @@ def __init__( normal_dot_vec += normal * vec self.add_equation("normal_dot_vec", normal_dot_vec) + + self._apply_detach() diff --git a/ppsci/equation/pde/poisson.py b/ppsci/equation/pde/poisson.py index e83fecde05..4f9551a23a 100644 --- a/ppsci/equation/pde/poisson.py +++ b/ppsci/equation/pde/poisson.py @@ -49,3 +49,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): poisson += p.diff(invar, 2) self.add_equation("poisson", poisson) + + self._apply_detach() diff --git a/ppsci/equation/pde/viv.py b/ppsci/equation/pde/viv.py index 68fd61a446..c3d85895f1 100644 --- a/ppsci/equation/pde/viv.py +++ b/ppsci/equation/pde/viv.py @@ -60,3 +60,5 @@ def __init__(self, rho: float, k1: float, k2: float): k2 = self.create_symbols(self.k2.name) f = self.rho * eta.diff(t_f, 2) + sp.exp(k1) * eta.diff(t_f) + sp.exp(k2) * eta self.add_equation("f", f) + + self._apply_detach() diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 8cf368c8c4..02d89242d9 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -40,6 +40,7 @@ __all__ = [ "lambdify", + "_cvt_to_key", ] @@ -116,6 +117,9 @@ def _cvt_to_key(expr: sp.Basic) -> str: Returns: str: Converted string key. """ + if isinstance(expr, sp.Function) and expr.name == equation.DETACH_FUNC_NAME: + return f"{_cvt_to_key(expr.args[0])}_{equation.DETACH_FUNC_NAME}" + if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): if hasattr(expr, "name"): # use name of custom function instead of itself. @@ -815,6 +819,7 @@ def _expr_to_callable_nodes( elif isinstance(node, sp.Function): if node.name == equation.DETACH_FUNC_NAME: callable_nodes.append(DetachNode(node)) + logger.debug(f"Detected detach node {node}") else: match_index = None for j, model in enumerate(models): From d764d818c07949cfe722148261a96d43d8f045fa Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 11 May 2024 12:44:43 +0000 Subject: [PATCH 2/6] add unitest for detach option --- test/equation/test_detach.py | 173 +++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 test/equation/test_detach.py diff --git a/test/equation/test_detach.py b/test/equation/test_detach.py new file mode 100644 index 0000000000..3e8f3b8047 --- /dev/null +++ b/test/equation/test_detach.py @@ -0,0 +1,173 @@ +import numpy as np +import paddle + +import ppsci +from ppsci import arch +from ppsci import equation + + +def test_equation_detach(): + # use N-S equation for test + all_items = [ + "u", + "u__x", + "u__y", + "u__x__x", + "v", + "v__x", + "v__y", + "v__x__x", + "p", + "p__x", + "p__y", + ] + model1 = arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16) + model2 = arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16) + input_data = { + "x": paddle.randn([16, 1]), + "y": paddle.randn([16, 1]), + } + input_data["x"].stop_gradient = False + input_data["y"].stop_gradient = False + for ii, in_state in enumerate(range(0, 1 << len(all_items), 5)): + detach_keys = [ + item for i, item in enumerate(all_items) if ((1 << i) & in_state) + ] + nu = 3.14 + rho = 0.817 + ns = equation.NavierStokes(nu, rho, 2, False, detach_keys=detach_keys) + model2.set_state_dict(model1.state_dict()) + + exprs = ppsci.lambdify( + list(ns.equations.values()), + model1, + fuse_derivative=False, + ) + for name, f in zip(ns.equations, exprs): + input_data[name] = f(input_data) + + def compute_loss(data_dict): + u = data_dict["u"] + v = data_dict["v"] + p = data_dict["p"] + + u__x = data_dict["u__x"] + u__y = data_dict["u__y"] + u__x__x = data_dict["u__x__x"] + u__y__y = data_dict["u__y__y"] + + v = data_dict["v"] + v__x = data_dict["v__x"] + v__y = data_dict["v__y"] + v__x__x = data_dict["v__x__x"] + v__y__y = data_dict["v__y__y"] + + p = data_dict["p"] + p__x = data_dict["p__x"] + p__y = data_dict["p__y"] + + if "u" in detach_keys: + u = u.detach() + if "v" in detach_keys: + v = v.detach() + if "p" in detach_keys: + p = p.detach() + if "u__x" in detach_keys: + u__x = u__x.detach() + if "u__y" in detach_keys: + u__y = u__y.detach() + if "u__x__x" in detach_keys: + u__x__x = u__x__x.detach() + if "u__y__y" in detach_keys: + u__y__y = u__y__y.detach() + if "v__x" in detach_keys: + v__x = v__x.detach() + if "v__y" in detach_keys: + v__y = v__y.detach() + if "v__x__x" in detach_keys: + v__x__x = v__x__x.detach() + if "v__y__y" in detach_keys: + v__y__y = v__y__y.detach() + if "p__x" in detach_keys: + p__x = p__x.detach() + if "p__y" in detach_keys: + p__y = p__y.detach() + + # continuity + continuity = u__x + v__y + # momentum_x + momentum_x = ( + u * u__x + v * u__y - nu * (u__x__x + u__y__y) + (1 / rho) * p__x + ) + # momentum_y + momentum_y = ( + u * v__x + v * v__y - nu * (v__x__x + v__y__y) + (1 / rho) * p__y + ) + + return ( + (continuity**2).sum() + + (momentum_x**2).sum() + + (momentum_y**2).sum() + ) + + loss1 = compute_loss(input_data) + + loss1.backward() + + ppsci.autodiff.clear() + + input_data = { + "x": input_data["x"], + "y": input_data["y"], + } + x, y = input_data["x"], input_data["y"] + t = model2(input_data) + u, v, p = t["u"], t["v"], t["p"] + + u__x = ppsci.autodiff.jacobian(u, x) + u__y = ppsci.autodiff.jacobian(u, y) + u__x__x = ppsci.autodiff.hessian(u, x) + u__y__y = ppsci.autodiff.hessian(u, y) + + v__x = ppsci.autodiff.jacobian(v, x) + v__y = ppsci.autodiff.jacobian(v, y) + v__x__x = ppsci.autodiff.hessian(v, x) + v__y__y = ppsci.autodiff.hessian(v, y) + + p__x = ppsci.autodiff.jacobian(p, x) + p__y = ppsci.autodiff.jacobian(p, y) + + loss2 = compute_loss( + { + "u": u, + "v": v, + "p": p, + "u__x": u__x, + "u__y": u__y, + "u__x__x": u__x__x, + "u__y__y": u__y__y, + "v__x": v__x, + "v__y": v__y, + "v__x__x": v__x__x, + "v__y__y": v__y__y, + "p__x": p__x, + "p__y": p__y, + } + ) + loss2.backward() + + np.testing.assert_allclose(loss1.numpy(), loss2.numpy()) + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + if (p1.grad is None) ^ (p2.grad is None): + raise AssertionError() + if p1.grad is not None and p2.grad is not None: + np.testing.assert_allclose(p1.grad.numpy(), p2.grad.numpy()) + + ppsci.autodiff.clear() + model1.clear_gradients() + model2.clear_gradients() + + +if __name__ == "__main__": + test_equation_detach() From 08dc68706ba6a14e1e8eb485601e0dcd41d5ae95 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 11 May 2024 14:50:05 +0000 Subject: [PATCH 3/6] fix access for 'name' when exp do not have 'name' attribute --- ppsci/utils/symbolic.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 02d89242d9..dafc69c59d 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -117,17 +117,15 @@ def _cvt_to_key(expr: sp.Basic) -> str: Returns: str: Converted string key. """ - if isinstance(expr, sp.Function) and expr.name == equation.DETACH_FUNC_NAME: + if isinstance(expr, sp.Function) and str(expr.func) == equation.DETACH_FUNC_NAME: return f"{_cvt_to_key(expr.args[0])}_{equation.DETACH_FUNC_NAME}" if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): - if hasattr(expr, "name"): - # use name of custom function instead of itself. - return expr.name - else: - return str(expr) + # use name of custom function(e.g. "f") instead of itself(e.g. "f(x, y)") + # for simplicity. + return str(expr.func) elif isinstance(expr, sp.Derivative): - # convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y" + # convert "Derivative(u(x,y),(x,2),(y,2))" to "u__x__x__y__y" expr_str = expr.args[0].name for symbol, order in expr.args[1:]: expr_str += f"__{symbol}" * order @@ -817,13 +815,13 @@ def _expr_to_callable_nodes( else: callable_nodes.append(OperatorNode(node)) elif isinstance(node, sp.Function): - if node.name == equation.DETACH_FUNC_NAME: + if str(node.func) == equation.DETACH_FUNC_NAME: callable_nodes.append(DetachNode(node)) logger.debug(f"Detected detach node {node}") else: match_index = None for j, model in enumerate(models): - if str(node.func.name) in model.output_keys: + if str(node.func) in model.output_keys: callable_nodes.append( LayerNode( node, @@ -833,13 +831,13 @@ def _expr_to_callable_nodes( if match_index is not None: raise ValueError( f"Name of function: '{node}' should be unique along given" - f" models, but got same output_key: '{node.func.name}' " + f" models, but got same output_key: '{str(node.func)}' " f"in given models[{match_index}] and models[{j}]." ) match_index = j # NOTE: Skip 'sdf' function, which should be already generated in # given data_dict - if match_index is None and node.name != "sdf": + if match_index is None and str(node.func) != "sdf": raise ValueError( f"Node {node} can not match any model in given model(s)." ) From 2d50bc4fb22b56517303af207fb0931a495c1ccf Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 11 May 2024 16:05:03 +0000 Subject: [PATCH 4/6] fix unitest --- ppsci/utils/symbolic.py | 7 +++++-- test/equation/test_detach.py | 21 ++++++++++----------- test/utils/test_symbolic.py | 9 ++++++--- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index dafc69c59d..2ceb09d341 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -123,7 +123,10 @@ def _cvt_to_key(expr: sp.Basic) -> str: if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): # use name of custom function(e.g. "f") instead of itself(e.g. "f(x, y)") # for simplicity. - return str(expr.func) + if hasattr(expr, "name"): + return expr.name + else: + return str(expr) elif isinstance(expr, sp.Derivative): # convert "Derivative(u(x,y),(x,2),(y,2))" to "u__x__x__y__y" expr_str = expr.args[0].name @@ -928,7 +931,7 @@ def _expr_to_callable_nodes( logger.debug( f"Fused {len(candidate_pos)} derivatives nodes: " f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into" - f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])" + f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])" ) # mark merged node diff --git a/test/equation/test_detach.py b/test/equation/test_detach.py index 3e8f3b8047..c262073baa 100644 --- a/test/equation/test_detach.py +++ b/test/equation/test_detach.py @@ -1,9 +1,8 @@ import numpy as np import paddle +import pytest import ppsci -from ppsci import arch -from ppsci import equation def test_equation_detach(): @@ -21,8 +20,8 @@ def test_equation_detach(): "p__x", "p__y", ] - model1 = arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16) - model2 = arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16) + model1 = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16) + model2 = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16) input_data = { "x": paddle.randn([16, 1]), "y": paddle.randn([16, 1]), @@ -33,15 +32,15 @@ def test_equation_detach(): detach_keys = [ item for i, item in enumerate(all_items) if ((1 << i) & in_state) ] - nu = 3.14 - rho = 0.817 - ns = equation.NavierStokes(nu, rho, 2, False, detach_keys=detach_keys) + nu = 1.314 + rho = 0.156 + ns = ppsci.equation.NavierStokes(nu, rho, 2, False, detach_keys=detach_keys) model2.set_state_dict(model1.state_dict()) exprs = ppsci.lambdify( list(ns.equations.values()), model1, - fuse_derivative=False, + fuse_derivative=True, ) for name, f in zip(ns.equations, exprs): input_data[name] = f(input_data) @@ -156,13 +155,13 @@ def compute_loss(data_dict): ) loss2.backward() - np.testing.assert_allclose(loss1.numpy(), loss2.numpy()) + np.testing.assert_allclose(loss1.numpy(), loss2.numpy(), 0.0, 0.0) for p1, p2 in zip(model1.parameters(), model2.parameters()): if (p1.grad is None) ^ (p2.grad is None): raise AssertionError() if p1.grad is not None and p2.grad is not None: - np.testing.assert_allclose(p1.grad.numpy(), p2.grad.numpy()) + np.testing.assert_allclose(p1.grad.numpy(), p2.grad.numpy(), 1e-5, 1e-5) ppsci.autodiff.clear() model1.clear_gradients() @@ -170,4 +169,4 @@ def compute_loss(data_dict): if __name__ == "__main__": - test_equation_detach() + pytest.main() diff --git a/test/utils/test_symbolic.py b/test/utils/test_symbolic.py index f82a842a56..86468bd43f 100644 --- a/test/utils/test_symbolic.py +++ b/test/utils/test_symbolic.py @@ -85,10 +85,13 @@ def test_multi_model_and_sdf(): tmp2_eval = ep_eval * tmp1_eval * k_eval out_var_reference = tmp1_eval + tmp2_eval - assert np.allclose(out_var_tensor.numpy(), out_var_reference.numpy()) + np.testing.assert_allclose( + out_var_tensor.numpy(), out_var_reference.numpy(), 1e-6, 0.0 + ) def test_complicated_symbolic(): + paddle.seed(2023) x_ten = paddle.randn([32, 1]) x_ten.stop_gradient = False y_ten = paddle.randn([32, 1]) @@ -136,13 +139,13 @@ def random_derivative(state): eqs_expected = ppsci.lambdify( targets, model_f, - fuse_derivative=True, + fuse_derivative=False, ) for i in range(len(targets)): output_fuse = eqs_fuse[i](input_data) output_expected = eqs_expected[i](input_data) - assert np.allclose(output_fuse.numpy(), output_expected.numpy()) + np.testing.assert_allclose(output_fuse.numpy(), output_expected.numpy()) ppsci.autodiff.clear() From da37a03e8814341decaaeefdc5c4483400db1a7f Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 12 May 2024 03:50:25 +0000 Subject: [PATCH 5/6] add example code for _apply_detach --- ppsci/equation/pde/base.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py index 96edbe0d61..b5affbcf75 100644 --- a/ppsci/equation/pde/base.py +++ b/ppsci/equation/pde/base.py @@ -89,6 +89,28 @@ def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Functi return expr def _apply_detach(self): + """ + Wrap detached sub_expr into detach(sub_expr) to prevent gradient + back-propagation, only for those items speicified in self.detach_keys. + + NOTE: This function is expected to be called after self.equations is ready in PDE.__init__. + + Examples: + >>> import ppsci + >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False) + >>> print(ns) + NavierStokes + continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y) + momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2)) + momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2)) + >>> detach_keys = ("u", "v__y") + >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys) + >>> print(ns) + NavierStokes + continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x) + momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2)) + momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2)) + """ if self.detach_keys is None: return @@ -145,7 +167,8 @@ def add_equation(self, name: str, equation: Callable): >>> equation = sympy.diff(u, x) + sympy.diff(u, y) >>> pde.add_equation('linear_pde', equation) >>> print(pde) - PDE, linear_pde: 2*x + 2*y + PDE + linear_pde: 2*x + 2*y """ self.equations.update({name: equation}) @@ -216,7 +239,7 @@ def set_state_dict( return self.learnable_parameters.set_state_dict(state_dict) def __str__(self): - return ", ".join( + return "\n".join( [self.__class__.__name__] - + [f"{name}: {eq}" for name, eq in self.equations.items()] + + [f" {name}: {eq}" for name, eq in self.equations.items()] ) From 9b9887650b42cce2466dd694616fcf62ae31f00b Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 12 May 2024 04:39:15 +0000 Subject: [PATCH 6/6] fix test_pde_base --- test/equation/test_pde_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/equation/test_pde_base.py b/test/equation/test_pde_base.py index 98b01a519f..eb9265ea74 100644 --- a/test/equation/test_pde_base.py +++ b/test/equation/test_pde_base.py @@ -129,7 +129,7 @@ def simple_equation(out): pde.add_equation("simple", simple_equation) - assert str(pde).startswith("PDE, simple: ") + assert str(pde).startswith("PDE\n simple: ") if __name__ == "__main__":