2222from typing import Union
2323
2424import paddle
25- import sympy
25+ import sympy as sp
2626from paddle import nn
2727
2828DETACH_FUNC_NAME = "detach"
@@ -33,7 +33,7 @@ class PDE:
3333
3434 def __init__ (self ):
3535 super ().__init__ ()
36- self .equations = {}
36+ self .equations : Dict [ str , Union [ Callable , sp . Basic ]] = {}
3737 # for PDE which has learnable parameter(s)
3838 self .learnable_parameters = nn .ParameterList ()
3939
@@ -42,7 +42,7 @@ def __init__(self):
4242 @staticmethod
4343 def create_symbols (
4444 symbol_str : str ,
45- ) -> Union [sympy .Symbol , Tuple [sympy .Symbol , ...]]:
45+ ) -> Union [sp .Symbol , Tuple [sp .Symbol , ...]]:
4646 """create symbolic variables.
4747
4848 Args:
@@ -61,11 +61,9 @@ def create_symbols(
6161 >>> print(symbols_xyz)
6262 (x, y, z)
6363 """
64- return sympy .symbols (symbol_str )
64+ return sp .symbols (symbol_str )
6565
66- def create_function (
67- self , name : str , invars : Tuple [sympy .Symbol , ...]
68- ) -> sympy .Function :
66+ def create_function (self , name : str , invars : Tuple [sp .Symbol , ...]) -> sp .Function :
6967 """Create named function depending on given invars.
7068
7169 Args:
@@ -86,14 +84,73 @@ def create_function(
8684 >>> print(f)
8785 f(x, y, z)
8886 """
89- expr = sympy .Function (name )(* invars )
87+ expr = sp .Function (name )(* invars )
9088
91- # wrap `expression(...)` to `detach(expression(...))`
92- # if name of expression is in given detach_keys
93- if self .detach_keys and name in self .detach_keys :
94- expr = sympy .Function (DETACH_FUNC_NAME )(expr )
9589 return expr
9690
91+ def _apply_detach (self ):
92+ """
93+ Wrap detached sub_expr into detach(sub_expr) to prevent gradient
94+ back-propagation, only for those items speicified in self.detach_keys.
95+
96+ NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.
97+
98+ Examples:
99+ >>> import ppsci
100+ >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
101+ >>> print(ns)
102+ NavierStokes
103+ continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
104+ momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
105+ momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
106+ >>> detach_keys = ("u", "v__y")
107+ >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
108+ >>> print(ns)
109+ NavierStokes
110+ continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
111+ momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
112+ momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
113+ """
114+ if self .detach_keys is None :
115+ return
116+
117+ from copy import deepcopy
118+
119+ from sympy .core .traversal import postorder_traversal
120+
121+ from ppsci .utils .symbolic import _cvt_to_key
122+
123+ for name , expr in self .equations .items ():
124+ if not isinstance (expr , sp .Basic ):
125+ continue
126+ # only process sympy expression
127+ expr_ = deepcopy (expr )
128+ for item in postorder_traversal (expr ):
129+ if _cvt_to_key (item ) in self .detach_keys :
130+ # inplace all related sub_expr into detach(sub_expr)
131+ expr_ = expr_ .replace (item , sp .Function (DETACH_FUNC_NAME )(item ))
132+
133+ # remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping
134+ expr_ = expr_ .replace (
135+ sp .Function (DETACH_FUNC_NAME )(
136+ sp .Function (DETACH_FUNC_NAME )(item )
137+ ),
138+ sp .Function (DETACH_FUNC_NAME )(item ),
139+ )
140+
141+ # remove unccessary detach wrapping for the first arg of Derivative
142+ for item_ in list (postorder_traversal (expr_ )):
143+ if isinstance (item_ , sp .Derivative ):
144+ if item_ .args [0 ].name == DETACH_FUNC_NAME :
145+ expr_ = expr_ .replace (
146+ item_ ,
147+ sp .Derivative (
148+ item_ .args [0 ].args [0 ], * item_ .args [1 :]
149+ ),
150+ )
151+
152+ self .equations [name ] = expr_
153+
97154 def add_equation (self , name : str , equation : Callable ):
98155 """Add an equation.
99156
@@ -110,7 +167,8 @@ def add_equation(self, name: str, equation: Callable):
110167 >>> equation = sympy.diff(u, x) + sympy.diff(u, y)
111168 >>> pde.add_equation('linear_pde', equation)
112169 >>> print(pde)
113- PDE, linear_pde: 2*x + 2*y
170+ PDE
171+ linear_pde: 2*x + 2*y
114172 """
115173 self .equations .update ({name : equation })
116174
@@ -181,7 +239,7 @@ def set_state_dict(
181239 return self .learnable_parameters .set_state_dict (state_dict )
182240
183241 def __str__ (self ):
184- return ", " .join (
242+ return "\n " .join (
185243 [self .__class__ .__name__ ]
186- + [f"{ name } : { eq } " for name , eq in self .equations .items ()]
244+ + [f" { name } : { eq } " for name , eq in self .equations .items ()]
187245 )
0 commit comments