diff --git a/deepxde/nn/paddle/deeponet.py b/deepxde/nn/paddle/deeponet.py index 57e09d543..f541fcd84 100755 --- a/deepxde/nn/paddle/deeponet.py +++ b/deepxde/nn/paddle/deeponet.py @@ -75,8 +75,12 @@ def forward(self, inputs): raise AssertionError( "Output sizes of branch net and trunk net do not match." ) - x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ] - x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1] + # Use the following formula to temporarily replace paddle.einsum() + # Because no higher(>=2) orderderivatives for the op now + x = paddle.sum(x_func*x_loc,axis = 1,keepdim=True) + # TODO: + # x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ] + # x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1] # Add bias if self.use_bias: x += self.b @@ -143,7 +147,11 @@ def forward(self, inputs): raise AssertionError( "Output sizes of branch net and trunk net do not match." ) - x = paddle.einsum("bi,ni->bn", x_func, x_loc) + # Use the following formula to temporarily replace paddle.einsum() + # Because no higher(>=2) orderderivatives for the op now + x = x_func@x_loc.T + # TODO: + # x = paddle.einsum("bi,ni->bn", x_func, x_loc) # Add bias x += self.b diff --git a/examples/operator/advection_aligned_pideeponet.py b/examples/operator/advection_aligned_pideeponet.py index bf94a3b78..e6671f9a5 100644 --- a/examples/operator/advection_aligned_pideeponet.py +++ b/examples/operator/advection_aligned_pideeponet.py @@ -1,8 +1,19 @@ -"""Backend supported: tensorflow.compat.v1, tensorflow""" +"""Backend supported: tensorflow.compat.v1, tensorflow, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np -from deepxde.backend import tf + +if dde.backend.backend_name == "paddle": + dim_x = 5 + sin = dde.backend.paddle.sin + cos = dde.backend.paddle.cos + concat = dde.backend.paddle.concat +else: + dim_x = 2 + sin = dde.backend.tf.sin + cos = dde.backend.tf.cos + concat = dde.backend.tf.concat + # PDE def pde(x, y, v): @@ -36,7 +47,7 @@ def func_ic(x, v): # Net net = dde.nn.DeepONetCartesianProd( [50, 128, 128, 128], - [2, 128, 128, 128], + [dim_x, 128, 128, 128], "tanh", "Glorot normal", ) @@ -45,8 +56,8 @@ def func_ic(x, v): def periodic(x): x, t = x[:, :1], x[:, 1:] x *= 2 * np.pi - return tf.concat( - [tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1 + return concat( + [cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1 ) diff --git a/examples/operator/advection_aligned_pideeponet_2d.py b/examples/operator/advection_aligned_pideeponet_2d.py index 26e57efa5..533c5e197 100644 --- a/examples/operator/advection_aligned_pideeponet_2d.py +++ b/examples/operator/advection_aligned_pideeponet_2d.py @@ -1,8 +1,19 @@ -"""Backend supported: tensorflow.compat.v1, tensorflow""" +"""Backend supported: tensorflow.compat.v1, tensorflow, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np -from deepxde.backend import tf + +if dde.backend.backend_name == "paddle": + dim_x = 5 + sin = dde.backend.paddle.sin + cos = dde.backend.paddle.cos + concat = dde.backend.paddle.concat +else: + dim_x = 2 + sin = dde.backend.tf.sin + cos = dde.backend.tf.cos + concat = dde.backend.tf.concat + # PDE def pde(x, y, v): @@ -41,17 +52,16 @@ def boundary(x, on_boundary): # Net net = dde.nn.DeepONetCartesianProd( [50, 128, 128, 128], - [2, 128, 128, 128], + [dim_x, 128, 128, 128], "tanh", "Glorot normal", ) - def periodic(x): x, t = x[:, :1], x[:, 1:] x *= 2 * np.pi - return tf.concat( - [tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1 + return concat( + [cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1 ) diff --git a/examples/operator/advection_unaligned_pideeponet.py b/examples/operator/advection_unaligned_pideeponet.py index fc9a0ac59..5c69ef61e 100644 --- a/examples/operator/advection_unaligned_pideeponet.py +++ b/examples/operator/advection_unaligned_pideeponet.py @@ -1,8 +1,19 @@ -"""Backend supported: tensorflow.compat.v1""" +"""Backend supported: tensorflow.compat.v1, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np -from deepxde.backend import tf + +if dde.backend.backend_name == "paddle": + dim_x = 5 + sin = dde.backend.paddle.sin + cos = dde.backend.paddle.cos + concat = dde.backend.paddle.concat +else: + dim_x = 2 + sin = dde.backend.tf.sin + cos = dde.backend.tf.cos + concat = dde.backend.tf.concat + # PDE def pde(x, y, v): @@ -36,7 +47,7 @@ def func_ic(x, v): # Net net = dde.nn.DeepONet( [50, 128, 128, 128], - [2, 128, 128, 128], + [dim_x, 128, 128, 128], "tanh", "Glorot normal", ) @@ -45,8 +56,8 @@ def func_ic(x, v): def periodic(x): x, t = x[:, :1], x[:, 1:] x *= 2 * np.pi - return tf.concat( - [tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1 + return concat( + [cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1 ) diff --git a/examples/operator/advection_unaligned_pideeponet_2d.py b/examples/operator/advection_unaligned_pideeponet_2d.py index f84c2826e..0398bb5c9 100644 --- a/examples/operator/advection_unaligned_pideeponet_2d.py +++ b/examples/operator/advection_unaligned_pideeponet_2d.py @@ -1,8 +1,18 @@ -"""Backend supported: tensorflow.compat.v1""" +"""Backend supported: tensorflow.compat.v1, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np -from deepxde.backend import tf + +if dde.backend.backend_name == "paddle": + dim_x = 5 + sin = dde.backend.paddle.sin + cos = dde.backend.paddle.cos + concat = dde.backend.paddle.concat +else: + dim_x = 2 + sin = dde.backend.tf.sin + cos = dde.backend.tf.cos + concat = dde.backend.tf.concat # PDE def pde(x, y, v): @@ -39,7 +49,7 @@ def boundary(x, on_boundary): # Net net = dde.nn.DeepONet( [50, 128, 128, 128], - [2, 128, 128, 128], + [dim_x, 128, 128, 128], "tanh", "Glorot normal", ) @@ -48,8 +58,8 @@ def boundary(x, on_boundary): def periodic(x): x, t = x[:, :1], x[:, 1:] x *= 2 * np.pi - return tf.concat( - [tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1 + return concat( + [cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1 ) diff --git a/examples/operator/antiderivative_aligned_pideeponet.py b/examples/operator/antiderivative_aligned_pideeponet.py index fb761ba1c..6e2627611 100644 --- a/examples/operator/antiderivative_aligned_pideeponet.py +++ b/examples/operator/antiderivative_aligned_pideeponet.py @@ -1,8 +1,13 @@ -"""Backend supported: tensorflow.compat.v1, tensorflow""" +"""Backend supported: tensorflow.compat.v1, tensorflow, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np -from deepxde.backend import tf + +if dde.backend.backend_name == "paddle": + transpose = dde.backend.paddle.transpose +else: + transpose = dde.backend.tf.transpose + dde.config.disable_xla_jit() @@ -37,7 +42,7 @@ def pde(x, u, v): # Hard constraint zero IC def zero_ic(inputs, outputs): - return outputs * tf.transpose(inputs[1]) + return outputs * transpose(inputs[1],[1,0]) net.apply_output_transform(zero_ic) diff --git a/examples/operator/antiderivative_unaligned_pideeponet.py b/examples/operator/antiderivative_unaligned_pideeponet.py index 10c58780a..0b6fecf68 100644 --- a/examples/operator/antiderivative_unaligned_pideeponet.py +++ b/examples/operator/antiderivative_unaligned_pideeponet.py @@ -1,4 +1,4 @@ -"""Backend supported: tensorflow.compat.v1""" +"""Backend supported: tensorflow.compat.v1, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np diff --git a/examples/operator/diff_rec_aligned_pideeponet.py b/examples/operator/diff_rec_aligned_pideeponet.py index 1d35d7f51..95e64570a 100644 --- a/examples/operator/diff_rec_aligned_pideeponet.py +++ b/examples/operator/diff_rec_aligned_pideeponet.py @@ -1,4 +1,4 @@ -"""Backend supported: tensorflow.compat.v1, tensorflow""" +"""Backend supported: tensorflow.compat.v1, tensorflow, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np diff --git a/examples/operator/diff_rec_unaligned_pideeponet.py b/examples/operator/diff_rec_unaligned_pideeponet.py index 56f5fec89..3d1cdc61c 100644 --- a/examples/operator/diff_rec_unaligned_pideeponet.py +++ b/examples/operator/diff_rec_unaligned_pideeponet.py @@ -1,4 +1,4 @@ -"""Backend supported: tensorflow.compat.v1""" +"""Backend supported: tensorflow.compat.v1, paddle""" import deepxde as dde import matplotlib.pyplot as plt import numpy as np diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py index d12d72031..64ee46bb6 100644 --- a/examples/pinn_forward/Burgers.py +++ b/examples/pinn_forward/Burgers.py @@ -1,4 +1,4 @@ -"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch""" +"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" import deepxde as dde import numpy as np