Skip to content

Commit

Permalink
add mathjax expression for equations and fix some code (#361)
Browse files Browse the repository at this point in the history
* Update doc for release/1.0 (#356)

* update links README.md and index.md from latest to release/1.0(test=document_fix)

* update install_setup.md

* update docstrings

* Cherry pick panorama (#360)

* fix panorama.png

* add mathjax expression for equations and fix some code
  • Loading branch information
HydrogenSulfate authored Jun 5, 2023
1 parent c0ae483 commit 6fe232d
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ppsci/arch/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class MLP(base.Arch):
activation (str, optional): Name of activation function. Defaults to "tanh".
skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
input_dim (Optional[int], optional): Number of input's dimension. Defaults to None.
input_dim (Optional[int]): Number of input's dimension. Defaults to None.
Examples:
>>> import ppsci
Expand Down
4 changes: 2 additions & 2 deletions ppsci/autodiff/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ def __call__(
Args:
ys (paddle.Tensor): Output tensor.
xs (paddle.Tensor): Input tensor.
component (Optional[int], optional): If `y` has the shape (batch_size, dim_y > 1), then `y[:, component]`
component (Optional[int]): If `y` has the shape (batch_size, dim_y > 1), then `y[:, component]`
is used to compute the Hessian. Do not use if `y` has the shape (batch_size,
1). Defaults to None.
i (int, optional): i-th input variable. Defaults to 0.
j (int, optional): j-th input variable. Defaults to 0.
grad_y (Optional[paddle.Tensor], optional): The gradient of `y` w.r.t. `xs`. Provide `grad_y` if known to avoid
grad_y (Optional[paddle.Tensor]): The gradient of `y` w.r.t. `xs`. Provide `grad_y` if known to avoid
duplicate computation. Defaults to None.
Returns:
Expand Down
6 changes: 5 additions & 1 deletion ppsci/equation/pde/biharmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@


class Biharmonic(base.PDE):
"""Class for biharmonic equation.
r"""Class for biharmonic equation.
$$
\nabla^4 \varphi = \dfrac{q}{D}
$$
Args:
dim (int): Dimension of equation.
Expand Down
6 changes: 5 additions & 1 deletion ppsci/equation/pde/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@


class Laplace(base.PDE):
"""Class for laplace equation.
r"""Class for laplace equation.
$$
\nabla^2 \varphi = 0
$$
Args:
dim (int): Dimension of equation.
Expand Down
18 changes: 16 additions & 2 deletions ppsci/equation/pde/linear_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,29 @@


class LinearElasticity(base.PDE):
"""Linear elasticity equations.
r"""Linear elasticity equations.
Use either (E, nu) or (lambda_, mu) to define the material properties.
$$
\begin{cases}
stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\
stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\
stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\
traction_{x} = \mathbf{n}_x \sigma_{xx} + \mathbf{n}_y \sigma_{xy} + \mathbf{n}_z \sigma_{xz} \\
traction_{y} = \mathbf{n}_y \sigma_{yx} + \mathbf{n}_y \sigma_{yy} + \mathbf{n}_z \sigma_{yz} \\
traction_{z} = \mathbf{n}_z \sigma_{zx} + \mathbf{n}_y \sigma_{zy} + \mathbf{n}_z \sigma_{zz} \\
navier_{x} = \rho(\dfrac{\partial^2 u}{\partial t}) - (\lambda + \mu)(\dfrac{\partial^2 u}{\partial x^2}+\dfrac{\partial^2 v}{\partial y \partial x} + \dfrac{\partial^2 w}{\partial z \partial x}) - \mu(\dfrac{\partial^2 u}{\partial x^2} + \dfrac{\partial^2 u}{\partial y^2} + \dfrac{\partial^2 u}{\partial z^2}) \\
navier_{y} = \rho(\dfrac{\partial^2 v}{\partial t}) - (\lambda + \mu)(\dfrac{\partial^2 v}{\partial x \partial y}+\dfrac{\partial^2 v}{\partial y^2} + \dfrac{\partial^2 w}{\partial z \partial y}) - \mu(\dfrac{\partial^2 v}{\partial x^2} + \dfrac{\partial^2 v}{\partial y^2} + \dfrac{\partial^2 v}{\partial z^2}) \\
navier_{z} = \rho(\dfrac{\partial^2 w}{\partial t}) - (\lambda + \mu)(\dfrac{\partial^2 w}{\partial x \partial z}+\dfrac{\partial^2 v}{\partial y \partial z} + \dfrac{\partial^2 w}{\partial z^2}) - \mu(\dfrac{\partial^2 w}{\partial x^2} + \dfrac{\partial^2 w}{\partial y^2} + \dfrac{\partial^2 w}{\partial z^2}) \\
\end{cases}
$$
Args:
E (Optional[float]): The Young's modulus. Defaults to None.
nu (Optional[float]): The Poisson's ratio. Defaults to None.
lambda_ (Optional[float]): Lamé's first parameter. Defaults to None.
mu (Optional[float]): Lamé's second parameter (shear modulus). Defaults to None.
rho (float, optional): Mass density.. Defaults to 1.
rho (float, optional): Mass density. Defaults to 1.
dim (int, optional): Dimension of the linear elasticity (2 or 3). Defaults to 3.
time (bool, optional): Whether contains time data. Defaults to False.
Expand Down
38 changes: 31 additions & 7 deletions ppsci/equation/pde/navier_stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,34 @@


class NavierStokes(base.PDE):
"""Class for navier-stokes equation.
r"""Class for navier-stokes equation.
$$
\begin{cases}
\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z} = 0 \\
\dfrac{\partial u}{\partial t} + u\dfrac{\partial u}{\partial x} + v\dfrac{\partial u}{\partial y} + w\dfrac{\partial w}{\partial z} =
- \dfrac{1}{\rho}\dfrac{\partial p}{\partial x}
+ \nu(
\dfrac{\partial ^2 u}{\partial x ^2}
+ \dfrac{\partial ^2 u}{\partial y ^2}
+ \dfrac{\partial ^2 u}{\partial z ^2}
) \\
\dfrac{\partial v}{\partial t} + u\dfrac{\partial v}{\partial x} + v\dfrac{\partial v}{\partial y} + w\dfrac{\partial w}{\partial z} =
- \dfrac{1}{\rho}\dfrac{\partial p}{\partial y}
+ \nu(
\dfrac{\partial ^2 v}{\partial x ^2}
+ \dfrac{\partial ^2 v}{\partial y ^2}
+ \dfrac{\partial ^2 v}{\partial z ^2}
) \\
\dfrac{\partial w}{\partial t} + u\dfrac{\partial w}{\partial x} + v\dfrac{\partial w}{\partial y} + w\dfrac{\partial w}{\partial z} =
- \dfrac{1}{\rho}\dfrac{\partial p}{\partial z}
+ \nu(
\dfrac{\partial ^2 w}{\partial x ^2}
+ \dfrac{\partial ^2 w}{\partial y ^2}
+ \dfrac{\partial ^2 w}{\partial z ^2}
) \\
\end{cases}
$$
Args:
nu (float): Dynamic viscosity.
Expand All @@ -43,8 +70,7 @@ def continuity_compute_func(out):
u, v = out["u"], out["v"]
continuity = jacobian(u, x) + jacobian(v, y)
if self.dim == 3:
z = out["z"]
w = out["w"]
z, w = out["z"], out["w"]
continuity += jacobian(w, z)
return continuity

Expand All @@ -64,8 +90,7 @@ def momentum_x_compute_func(out):
t = out["t"]
momentum_x += jacobian(u, t)
if self.dim == 3:
z = out["z"]
w = out["w"]
z, w = out["z"], out["w"]
momentum_x += w * jacobian(u, z)
momentum_x -= nu / rho * hessian(u, z)
return momentum_x
Expand All @@ -86,8 +111,7 @@ def momentum_y_compute_func(out):
t = out["t"]
momentum_y += jacobian(v, t)
if self.dim == 3:
z = out["z"]
w = out["w"]
z, w = out["z"], out["w"]
momentum_y += w * jacobian(v, z)
momentum_y -= nu / rho * hessian(v, z)
return momentum_y
Expand Down
6 changes: 5 additions & 1 deletion ppsci/equation/pde/normal_dot_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@


class NormalDotVec(base.PDE):
"""NormalDotVec.
r"""NormalDotVec.
$$
\mathbf{n} \cdot \mathbf{v} = 0
$$
Args:
velocity_keys (Tuple[str, ...]): Keys for velocity(ies).
Expand Down
6 changes: 5 additions & 1 deletion ppsci/equation/pde/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@


class Poisson(base.PDE):
"""Class for poisson equation.
r"""Class for poisson equation.
$$
\nabla^2 \varphi = C
$$
Args:
dim (int): Dimension of equation.
Expand Down
10 changes: 7 additions & 3 deletions ppsci/equation/pde/viv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@


class Vibration(base.PDE):
"""Vortex induced vibration equation.
r"""Vortex induced vibration equation.
$$
\rho \dfrac{\partial^2 \eta}{\partial t^2} + e^{k1} \dfrac{\partial \eta}{\partial t} + e^{k2} \eta = f
$$
Args:
rho (float): Generalized mass.
Expand All @@ -37,12 +41,12 @@ def __init__(self, rho: float, k1: float, k2: float):
super().__init__()
self.rho = rho
self.k1 = paddle.create_parameter(
shape=[1],
shape=[],
dtype=paddle.get_default_dtype(),
default_initializer=initializer.Constant(k1),
)
self.k2 = paddle.create_parameter(
shape=[1],
shape=[],
dtype=paddle.get_default_dtype(),
default_initializer=initializer.Constant(k2),
)
Expand Down
2 changes: 1 addition & 1 deletion ppsci/geometry/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def scale(self, scale, center=(0, 0, 0)):
open3d.utility.Vector3dVector(vertices),
open3d.utility.Vector3iVector(faces),
)
open3d_mesh.scale(scale, center)
open3d_mesh = open3d_mesh.scale(scale, center)
self.py_mesh = pymesh.form_mesh(
np.asarray(open3d_mesh.vertices, dtype=paddle.get_default_dtype()), faces
)
Expand Down
6 changes: 3 additions & 3 deletions ppsci/loss/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ class IntegralLoss(base.Loss):
$$
L =
\begin{cases}
\dfrac{1}{N} \Vert \mathbf{s} \circ \mathbf{x} - \mathbf{y} \Vert_2^2, & \text{if reduction='mean'} \\
\Vert \mathbf{s} \circ \mathbf{x} - \mathbf{y} \Vert_2^2, & \text{if reduction='sum'}
\dfrac{1}{N} \Vert \displaystyle\sum_{i=1}^{M}{\mathbf{s}_i \cdot \mathbf{x}_i} - \mathbf{y} \Vert_2^2, & \text{if reduction='mean'} \\
\Vert \displaystyle\sum_{i=0}^{M}{\mathbf{s}_i \cdot \mathbf{x}_i} - \mathbf{y} \Vert_2^2, & \text{if reduction='sum'}
\end{cases}
$$
$$
\mathbf{x}, \mathbf{y}, \mathbf{s} \in \mathcal{R}^{N}
\mathbf{x}, \mathbf{s} \in \mathcal{R}^{M \times N}, \mathbf{y} \in \mathcal{R}^{N}
$$
Args:
Expand Down

0 comments on commit 6fe232d

Please sign in to comment.