Backend paddle: Support deeponet and other examples
lijialin03 committed Aug 9, 2023
1 parent 197f298 commit dc16f05
Showing 12 changed files with 115 additions and 57 deletions.
24 changes: 7 additions & 17 deletions deepxde/icbc/boundary_conditions.py
@@ -175,13 +175,11 @@ class PointSetBC:
component: Integer or a list of integers. The output components satisfying this BC.
List of integers only supported for the backend PyTorch.
batch_size: The number of points per minibatch, or `None` to return all points.
This is only supported for the backend PyTorch.
This is only supported for the backends PyTorch and PaddlePaddle.
shuffle: Randomize the order on each pass through the data when batching.
"""

def __init__(
self, points, values, component=0, batch_size=None, shuffle=True
):
def __init__(self, points, values, component=0, batch_size=None, shuffle=True):
self.points = np.array(points, dtype=config.real(np))
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))
self.component = component
@@ -193,13 +191,11 @@ def __init__(
self.batch_size = batch_size

if batch_size is not None: # batch iterator and state
if backend_name != "pytorch":
if backend_name not in ["pytorch", "paddle"]:
raise RuntimeError(
"batch_size only implemented for pytorch backend"
"batch_size only implemented for pytorch and paddle backend"
)
self.batch_sampler = data.sampler.BatchSampler(
len(self), shuffle=shuffle
)
self.batch_sampler = data.sampler.BatchSampler(len(self), shuffle=shuffle)
self.batch_indices = None

def __len__(self):
@@ -218,15 +214,9 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
outputs[beg:end, self.component : self.component + 1]
- self.values[self.batch_indices]
)
return (
outputs[beg:end, self.component]
- self.values[self.batch_indices]
)
return outputs[beg:end, self.component] - self.values[self.batch_indices]
if isinstance(self.component, numbers.Number):
return (
outputs[beg:end, self.component : self.component + 1]
- self.values
)
return outputs[beg:end, self.component : self.component + 1] - self.values
# When a concat is provided, the following code works 'fast' in paddle cpu,
# and slow in both tensorflow backends, jax untested.
# tf.gather can be used instead of for loop but is also slow
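
Usage sketch (illustrative, not part of this commit; the observation data below are hypothetical): with the pytorch or paddle backend selected, PointSetBC can now be given a batch_size so that the observation points are fed in shuffled minibatches.

import numpy as np
import deepxde as dde

# Hypothetical observations for an inverse problem.
observe_x = np.linspace(-1, 1, 100).reshape(-1, 1)
observe_u = np.sin(np.pi * observe_x)

# Draw shuffled minibatches of 32 observation points per pass.
observe_bc = dde.icbc.PointSetBC(
    observe_x, observe_u, component=0, batch_size=32, shuffle=True
)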
20 changes: 13 additions & 7 deletions deepxde/nn/paddle/deeponet.py
@@ -57,8 +57,7 @@ def __init__(
if use_bias:
# register bias to parameter for updating in optimizer and storage
self.b = self.create_parameter(
shape=(1, ),
default_initializer=initializers.get("zeros")
shape=(1,), default_initializer=initializers.get("zeros")
)

def forward(self, inputs):
@@ -75,8 +74,12 @@ def forward(self, inputs):
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ]
x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1]
# Temporarily replace paddle.einsum() with the elementwise product and sum below,
# because paddle.einsum() does not yet support higher-order (>=2) derivatives.
x = paddle.sum(x_func * x_loc, axis=1, keepdim=True)
# TODO:
# x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ]
# x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1]
# Add bias
if self.use_bias:
x += self.b
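
A quick NumPy check (illustrative shapes only) that the replacement above is equivalent: summing the elementwise product over the feature axis reproduces the "bi,bi->b" contraction followed by a reshape to a column.

import numpy as np

x_func = np.random.rand(4, 8)  # branch output, shape [batch_size, width]
x_loc = np.random.rand(4, 8)   # trunk output, same shape

via_einsum = np.einsum("bi,bi->b", x_func, x_loc).reshape(-1, 1)
via_sum = np.sum(x_func * x_loc, axis=1, keepdims=True)
assert np.allclose(via_einsum, via_sum)  # same result up to floating point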
@@ -124,8 +127,7 @@ def __init__(
self.trunk = FNN(layer_sizes_trunk, self.activation_trunk, kernel_initializer)
# register bias to parameter for updating in optimizer and storage
self.b = self.create_parameter(
shape=(1, ),
default_initializer=initializers.get("zeros")
shape=(1,), default_initializer=initializers.get("zeros")
)
self.regularizer = regularization

@@ -143,7 +145,11 @@ def forward(self, inputs):
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
x = paddle.einsum("bi,ni->bn", x_func, x_loc)
# Temporarily replace paddle.einsum() with the plain matrix product below,
# because paddle.einsum() does not yet support higher-order (>=2) derivatives.
x = x_func @ x_loc.T
# TODO:
# x = paddle.einsum("bi,ni->bn", x_func, x_loc)
# Add bias
x += self.b
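
Likewise, a NumPy sketch (illustrative shapes only) showing that the plain matrix product above reproduces the "bi,ni->bn" contraction used by the Cartesian-product DeepONet.

import numpy as np

x_func = np.random.rand(4, 8)   # branch output, shape [batch_size, width]
x_loc = np.random.rand(16, 8)   # trunk output, shape [num_points, width]

via_einsum = np.einsum("bi,ni->bn", x_func, x_loc)
via_matmul = x_func @ x_loc.T   # [batch_size, num_points] grid of predictions
assert np.allclose(via_einsum, via_matmul)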

25 changes: 19 additions & 6 deletions examples/operator/advection_aligned_pideeponet.py
@@ -1,8 +1,23 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
import paddle

dim_x = 5
sin = paddle.sin
cos = paddle.cos
concat = paddle.concat
else:
from deepxde.backend import tf

dim_x = 2
sin = tf.sin
cos = tf.cos
concat = tf.concat


# PDE
def pde(x, y, v):
@@ -36,7 +51,7 @@ def func_ic(x, v):
# Net
net = dde.nn.DeepONetCartesianProd(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
@@ -45,9 +60,7 @@ def func_ic(x, v):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
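
A small NumPy shape check (illustrative inputs) of the periodic feature transform above: the two input columns (x, t) become the five features (cos x, sin x, cos 2x, sin 2x, t), matching the dim_x = 5 used to build the trunk net under the paddle backend.

import numpy as np

xt = np.random.rand(10, 2)                # hypothetical trunk inputs: columns x and t
x, t = xt[:, :1] * (2 * np.pi), xt[:, 1:]
features = np.concatenate(
    [np.cos(x), np.sin(x), np.cos(2 * x), np.sin(2 * x), t], axis=1
)
assert features.shape == (10, 5)          # five features, hence dim_x = 5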
25 changes: 19 additions & 6 deletions examples/operator/advection_aligned_pideeponet_2d.py
@@ -1,8 +1,23 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
import paddle

dim_x = 5
sin = paddle.sin
cos = paddle.cos
concat = paddle.concat
else:
from deepxde.backend import tf

dim_x = 2
sin = tf.sin
cos = tf.cos
concat = tf.concat


# PDE
def pde(x, y, v):
@@ -41,7 +56,7 @@ def boundary(x, on_boundary):
# Net
net = dde.nn.DeepONetCartesianProd(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
@@ -50,9 +65,7 @@ def boundary(x, on_boundary):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
25 changes: 19 additions & 6 deletions examples/operator/advection_unaligned_pideeponet.py
@@ -1,8 +1,23 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
import paddle

dim_x = 5
sin = paddle.sin
cos = paddle.cos
concat = paddle.concat
else:
from deepxde.backend import tf

dim_x = 2
sin = tf.sin
cos = tf.cos
concat = tf.concat


# PDE
def pde(x, y, v):
@@ -36,7 +51,7 @@ def func_ic(x, v):
# Net
net = dde.nn.DeepONet(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
@@ -45,9 +60,7 @@ def func_ic(x, v):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
25 changes: 19 additions & 6 deletions examples/operator/advection_unaligned_pideeponet_2d.py
@@ -1,8 +1,23 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
import paddle

dim_x = 5
sin = paddle.sin
cos = paddle.cos
concat = paddle.concat
else:
from deepxde.backend import tf

dim_x = 2
sin = tf.sin
cos = tf.cos
concat = tf.concat


# PDE
def pde(x, y, v):
@@ -39,7 +54,7 @@ def boundary(x, on_boundary):
# Net
net = dde.nn.DeepONet(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
@@ -48,9 +63,7 @@ def boundary(x, on_boundary):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
15 changes: 12 additions & 3 deletions examples/operator/antiderivative_aligned_pideeponet.py
@@ -1,8 +1,16 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
import paddle

transpose = paddle.transpose
else:
from deepxde.backend import tf

transpose = tf.transpose


dde.config.disable_xla_jit()
@@ -35,9 +43,10 @@ def pde(x, u, v):
"Glorot normal",
)


# Hard constraint zero IC
def zero_ic(inputs, outputs):
return outputs * tf.transpose(inputs[1])
return outputs * transpose(inputs[1], [1, 0])


net.apply_output_transform(zero_ic)
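
NumPy sketch (illustrative shapes) of the hard constraint above: multiplying the raw outputs by the transposed trunk inputs (the times) forces u(t=0) = 0 exactly. The explicit perm [1, 0] is passed because paddle.transpose requires a perm argument, whereas tf.transpose reverses the axes by default.

import numpy as np

t = np.linspace(0, 1, 5).reshape(-1, 1)   # trunk inputs (times), shape [num_times, 1]
outputs = np.random.rand(3, 5)            # raw net outputs, shape [batch, num_times]

constrained = outputs * t.T               # same as outputs * transpose(t, [1, 0])
assert np.allclose(constrained[:, 0], 0)  # the initial condition u(0) = 0 holds exactly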
3 changes: 2 additions & 1 deletion examples/operator/antiderivative_unaligned_pideeponet.py
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
@@ -30,6 +30,7 @@ def pde(x, u, v):
"Glorot normal",
)


# Hard constraint zero IC
def zero_ic(inputs, outputs):
return outputs * inputs[1]
2 changes: 1 addition & 1 deletion examples/operator/diff_rec_aligned_pideeponet.py
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
2 changes: 1 addition & 1 deletion examples/operator/diff_rec_unaligned_pideeponet.py
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
2 changes: 1 addition & 1 deletion examples/pinn_forward/Burgers.py
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
import deepxde as dde
import numpy as np

4 changes: 2 additions & 2 deletions examples/pinn_inverse/elliptic_inverse_field_batch.py
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1, pytorch"""
"""Backend supported: tensorflow.compat.v1, pytorch, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
@@ -57,7 +57,7 @@ def sol(x):
plt.plot(x, uhat, "--", label="u_NN")
plt.legend()

qtrue = -np.pi ** 2 * np.sin(np.pi * x)
qtrue = -np.pi**2 * np.sin(np.pi * x)
print("l2 relative error for q: " + str(dde.metrics.l2_relative_error(qtrue, qhat)))
plt.figure()
plt.plot(x, qtrue, "-", label="q_true")
