From 271e15c05768efee94f6c7e6d1dd3eea72b33c22 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Wed, 16 Aug 2023 02:44:22 +0000 Subject: [PATCH] change parameters of Lotka_Volterra example --- examples/pinn_forward/Lotka_Volterra.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/pinn_forward/Lotka_Volterra.py b/examples/pinn_forward/Lotka_Volterra.py index cfba5858a..89b790094 100644 --- a/examples/pinn_forward/Lotka_Volterra.py +++ b/examples/pinn_forward/Lotka_Volterra.py @@ -3,8 +3,10 @@ import matplotlib.pyplot as plt import numpy as np from scipy import integrate + # Import tf if using backend tensorflow.compat.v1 or tensorflow from deepxde.backend import tf + # Import torch if using backend pytorch # import torch # Import paddle if using backend paddle @@ -51,6 +53,7 @@ def ode_system(x, y): dp_t - 1 / ub * rb * (0.02 * r * ub * p * ub - 1.06 * p * ub), ] + geom = dde.geometry.TimeDomain(0, 1.0) data = dde.data.PDE(geom, ode_system, [], 3000, 2, num_test=3000) @@ -59,6 +62,7 @@ def ode_system(x, y): initializer = "Glorot normal" net = dde.nn.FNN(layer_size, activation, initializer) + # Backend tensorflow.compat.v1 or tensorflow def input_transform(t): return tf.concat( @@ -73,6 +77,8 @@ def input_transform(t): ), axis=1, ) + + # Backend pytorch # def input_transform(t): # return torch.cat( @@ -93,7 +99,7 @@ def input_transform(t): # def input_transform(t): # if t.ndim == 1: # t = t[None] -# +# # return jnp.concatenate( # [ # jnp.sin(t), @@ -101,12 +107,15 @@ def input_transform(t): # axis=1 # ) + # hard constraints: x(0) = 100, y(0) = 15 # Backend tensorflow.compat.v1 or tensorflow def output_transform(t, y): y1 = y[:, 0:1] y2 = y[:, 1:2] return tf.concat([y1 * tf.tanh(t) + 100 / ub, y2 * tf.tanh(t) + 15 / ub], axis=1) + + # Backend pytorch # def output_transform(t, y): # y1 = y[:, 0:1] @@ -131,9 +140,9 @@ def output_transform(t, y): model = dde.Model(data, net) model.compile("adam", lr=0.001) -losshistory, train_state = model.train(iterations=1000) +losshistory, train_state = model.train(iterations=50000) # Most backends except jax can have a second fine tuning of the solution -model.compile("L-BFGS", lr=0.001) +model.compile("L-BFGS") losshistory, train_state = model.train() dde.saveplot(losshistory, train_state, issave=True, isplot=True)